Commit 861c889 · unverified · 0 parent(s)

cbensimon (HF Staff) committed: Initial commit
.circleci/config.yml ADDED
@@ -0,0 +1,18 @@
+ version: 2
+ jobs:
+   build:
+     working_directory: ~/autoprompt
+     docker:
+       - image: circleci/python:3.7
+         environment:
+           OMP_NUM_THREADS: 1
+     resource_class: medium
+     parallelism: 1
+     steps:
+       - checkout
+       - run: pip install --upgrade pip
+       - run: pip install -r requirements.txt
+       - run: python -m pytest --disable-warnings
+       - store_test_results:
+           path: test-results
+
.gitignore ADDED
@@ -0,0 +1,66 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # IPython checkpoints
+ .ipynb_checkpoints
+
+ # Miscellaneous
+ .DS_Store
+ .vscode/
+ out/
+ #data/
README.md ADDED
@@ -0,0 +1,168 @@
+ # AutoPrompt
+ An automated method based on gradient-guided search to create prompts for a diverse set of NLP tasks. AutoPrompt demonstrates that masked language models (MLMs) have an innate ability to perform sentiment analysis, natural language inference, fact retrieval, and relation extraction. Check out our [website](https://ucinlp.github.io/autoprompt/) for the paper and more information.
+
+ ## Table of Contents
+ * [Setup](#setup)
+ * [Generating Prompts](#generating-prompts)
+ * [Label Token Selection](#label-token-selection)
+ * [Evaluation for Fact Retrieval and Relation Extraction](#evaluation-for-fact-retrieval-and-relation-extraction)
+ * [Citation](#citation)
+
+ ## Setup
+
+ ### 1. Create conda environment
+ ```
+ conda create -n autoprompt -y python=3.7 && conda activate autoprompt
+ ```
+
+ ### 2. Install dependencies
+ Install the required packages:
+ ```
+ pip install -r requirements.txt
+ ```
+ Then download the spaCy model:
+ ```
+ python -m spacy download en
+ ```
+
+ ### 3. Download the data
+ The datasets for sentiment analysis, NLI, fact retrieval, and relation extraction are available to download [here](https://drive.google.com/drive/folders/1vVhgnSXmbuJb6GLPn_FErY1xDTh1xyv-?usp=sharing).
+
+ There are a couple of different datasets for fact retrieval and relation extraction, so here is a brief overview of each:
+ - Fact Retrieval
+   - `original`: We used the T-REx subset provided by LAMA as our test set and gathered more facts from the [original T-REx dataset](https://hadyelsahar.github.io/t-rex/), which we partitioned into train and dev sets
+   - `original_rob`: We filtered the facts in `original` so that each object is a single token for both BERT and RoBERTa
+   - `trex`: We split the extra T-REx data collected (for the train/dev sets of `original`) into train, dev, and test sets
+ - Relation Extraction
+   - We trimmed the `original` dataset to accommodate both the [RE baseline](https://github.com/UKPLab/emnlp2017-relation-extraction) and RoBERTa. We also excluded the relations `P527` and `P1376` because the RE baseline doesn't consider them.
+
+ ## Generating Prompts
+
+ ### Quick Overview of Templates
+ A prompt is constructed by mapping things like the original input and trigger tokens to a template that looks something like
+
+ `[CLS] {sub_label} [T] [T] [T] [P]. [SEP]`
+
+ The example above is a template for generating fact retrieval prompts with 3 trigger tokens, where `{sub_label}` is a placeholder for the subject in any (subject, relation, object) triplet. `[P]` marks the position of the special `[MASK]` token that the language model will use to "fill in the blank". `[T]` marks a trigger token; the set of trigger tokens is shared across all prompts.
+
+ The special tokens depend on which language model (i.e., BERT or RoBERTa) you use to generate prompts. For BERT, attach `[CLS]` and `[SEP]` to each end of the template. For RoBERTa, use `<s>` and `</s>` instead.
+
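+ For example, the fact retrieval template from above written out for each model (both variants appear in the example commands below):
+ ```
+ BERT:    [CLS] {sub_label} [T] [T] [T] [P] . [SEP]
+ RoBERTa: <s> {sub_label} [T] [T] [T] [P] . </s>
+ ```
+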
+ ### Sentiment Analysis
+ ```
+ python -m autoprompt.create_trigger \
+     --train glue_data/SST-2/train.tsv \
+     --dev glue_data/SST-2/dev.tsv \
+     --template '<s> {sentence} [T] [T] [T] [P] . </s>' \
+     --label-map '{"0": ["Ġworse", "Ġincompetence", "ĠWorse", "Ġblamed", "Ġsucked"], "1": ["ĠCris", "Ġmarvelous", "Ġphilanthrop", "Ġvisionary", "Ġwonderful"]}' \
+     --num-cand 100 \
+     --accumulation-steps 30 \
+     --bsz 24 \
+     --eval-size 48 \
+     --iters 180 \
+     --model-name roberta-large
+ ```
+
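+ The `Ġ` prefixes in the label map above are RoBERTa's byte-pair-encoding marker for a token that begins with a space. To look up the exact token string for a new label word, a quick check with the HuggingFace tokenizer looks like this (a minimal sketch; substitute any word):
+ ```
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained('roberta-large')
+ # A leading space yields the 'Ġ'-prefixed word-initial token.
+ print(tokenizer.tokenize(' marvelous'))  # expected: ['Ġmarvelous']
+ ```
+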
+ ### Natural Language Inference
+ ```
+ python -m autoprompt.create_trigger --train SICK_TRAIN_ALL_S.tsv --dev SICK_DEV_ALL_S.tsv --template '<s> {sentence_A} [P] [T] [T] [T] [T] {sentence_B} </s>' --label-map '{"ENTAILMENT": ["\u0120Taiwan", "\u0120Ara", "abet"], "CONTRADICTION": ["\u0120Only", "\u0120Didn", "\u0120BUT"], "NEUTRAL": ["icy", "oder", "agna"]}' --bsz 120 --model-name roberta-large
+ ```
+
+ ### Fact Retrieval
+ ```
+ python -m autoprompt.create_trigger \
+     --train $path/train.jsonl \
+     --dev $path/dev.jsonl \
+     --template '<s> {sub_label} [T] [T] [T] [P] . </s>' \
+     --num-cand 10 \
+     --accumulation-steps 1 \
+     --model-name roberta-large \
+     --bsz 56 \
+     --eval-size 56 \
+     --iters 1000 \
+     --label-field 'obj_label' \
+     --tokenize-labels \
+     --filter \
+     --print-lama
+ ```
+
+ ### Relation Extraction
+ ```
+ python -m autoprompt.create_trigger \
+     --train $path/train.jsonl \
+     --dev $path/dev.jsonl \
+     --template '[CLS] {context} [SEP] {sub_label} [T] [T] [T] [P] . [SEP]' \
+     --num-cand 10 \
+     --accumulation-steps 1 \
+     --model-name bert-base-cased \
+     --bsz 32 \
+     --eval-size 32 \
+     --iters 500 \
+     --label-field 'obj_label' \
+     --tokenize-labels \
+     --filter \
+     --print-lama \
+     --use-ctx
+ ```
+
+ ## Label Token Selection
+
+ For sentiment analysis:
+ ```
+ python -m autoprompt.label_search --train ../data/SST-2/train.tsv --template '[CLS] {sentence} [T] [T] [T] [P]. [SEP]' --label-map '{"0": 0, "1": 1}' --iters 50 --model-name 'bert-base-cased'
+ ```
+
+ For NLI:
+ ```
+ python -m autoprompt.label_search --train ../data/SICK-E-balanced/3-balance/SICK_TRAIN_ALL_S.tsv --template '[CLS] {sentence} [T] [T] [T] [P]. [SEP]' --label-map '{"entailment": 0, "contradiction": 1, "neutral": 2}' --iters 50 --model-name 'bert-base-cased'
+ ```
+
+ ## Evaluation for Fact Retrieval and Relation Extraction
+
+ ### 1. Set up LAMA
+ Clone [our fork](https://github.com/taylorshin/LAMA) of the LAMA repo and follow the directions to set it up outside of the AutoPrompt repo.
+ We recommend creating a separate conda environment for LAMA because it has different dependencies and requirements.
+
+ Copy the AutoPrompt data folder into the `data` directory of LAMA, or set `data_path_pre` in `scripts/run_experiments.py` to a custom data location.
+
+ To get LAMA to work with RoBERTa, run the following commands:
+ ```
+ mkdir pre-trained_language_models/roberta
+ cd pre-trained_language_models/roberta
+ curl -O https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
+ tar -xvzf roberta.large.tar.gz
+ ```
+
+ ### 2. Update prompts
+ Update the `data/relations.jsonl` file with your own automatically generated prompts.
+
+ ### 3. Configure settings
+ To change the evaluation settings, go to `scripts/run_experiments.py` and update the configurable values accordingly.
+ Note: each configurable setting is marked with a `[CONFIGURABLE]` comment.
+
+ - Uncomment the settings of the LM you want to evaluate with (and comment out the other LM settings) in the `LMs` list at the top of the file
+ - Update the `common_vocab_filename` field to the appropriate file. Anything evaluating both BERT and RoBERTa requires this field to be `common_vocab_cased_rob.txt` instead of the usual `common_vocab_cased.txt`.
+ - Set `use_ctx` to `True` if running evaluation for Relation Extraction
+ - Set `synthetic` to `True` for perturbed-sentence evaluation of Relation Extraction
+ - In the `get_TREx_parameters` function, set `data_path_pre` to the corresponding data path (e.g. `"../data/relation_extraction"` for Relation Extraction)
+
+ ### 4. Evaluate prompts
+ Run the evaluation code:
+ ```
+ python scripts/run_experiments.py
+ ```
+
+ ### 5. Miscellaneous
+ If you hit `ModuleNotFoundError: No module named 'lama'`, set `PYTHONPATH`:
+ ```
+ export PYTHONPATH="${PYTHONPATH}:/path/to/the/AutoPrompt/repo"
+ ```
+
+ ## Citation
+ ```
+ @inproceedings{autoprompt:emnlp20,
+     author = {Taylor Shin and Yasaman Razeghi and Robert L. Logan IV and Eric Wallace and Sameer Singh},
+     title = {{AutoPrompt}: Eliciting Knowledge from Language Models with Automatically Generated Prompts},
+     booktitle = {Empirical Methods in Natural Language Processing (EMNLP)},
+     year = {2020}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,580 @@
+ import csv
+ from dataclasses import dataclass
+ import io
+ import json
+ import logging
+ import random
+ import sys
+ from typing import Dict, List
+
+ import pandas as pd
+ import streamlit as st
+ import torch
+ import transformers
+ from tqdm import tqdm
+
+ from autoprompt import utils
+ import autoprompt.create_trigger as ct
+
+
+ # logging.getLogger("streamlit.caching").addHandler(logging.StreamHandler(sys.stdout))
+ # logging.getLogger("streamlit.caching").setLevel(logging.DEBUG)
+
+ logger = logging.getLogger(__name__)
+
+
+ with open('assets/sst2_train.jsonl', 'r') as f:
+     DEFAULT_TRAIN = [json.loads(line) for line in f]
+
+
+ @dataclass
+ class CacheTest:
+     """
+     Stores whether the train button has been pressed for a given
+     set of inputs to run_autoprompt.
+     """
+     is_test: bool
+
+
+ class CacheMiss(Exception):
+     pass
+
+
+ def css_hack():
+     """
+     Inject some style into this app. ヽ(⌐■_■)ノ
+     """
+     st.markdown(
+         """
+ <style>
+ code {
+     color: #eec66d;
+ }
+ .css-gtmd9c a {
+     color: #6f98af;
+ }
+ </style>
+ """,
+         unsafe_allow_html=True
+     )
+
+
+ # Setting eq and frozen ensures that a __hash__ method is generated, which is needed for
+ # caching to properly respond to changed args.
+ @dataclass(eq=True, frozen=True)
+ class Args:
+     # Configurable
+     template: str
+     model_name: str
+     iters: int
+     num_cand: int
+     accumulation_steps: int
+
+     # Non-Configurable
+     seed = 0
+     sentence_size = 64
+     tokenize_labels = True
+     filter = False
+     initial_trigger = None
+     label_field = "label"
+     bsz = 32
+     eval_size = 1
+
+     @classmethod
+     def from_streamlit(cls):
+         st.sidebar.image('assets/icon.png', width=150)
+         st.sidebar.markdown('### Training Parameters')
+         model_name = st.sidebar.selectbox(
+             "Model",
+             options=['roberta-large', 'bert-base-cased'],
+             help="Language model used for training and evaluation."
+         )
+         iters = int(st.sidebar.number_input(
+             "Iterations",
+             value=10,
+             min_value=1,
+             max_value=100,
+             help="Number of trigger search iterations. Larger values may yield better results."
+         ))
+         num_cand = int(st.sidebar.number_input(
+             "Number of Candidates",
+             value=25,
+             min_value=1,
+             max_value=100,
+             help="Number of candidate trigger token replacements to evaluate during each search "
+                  "iteration. Larger values may yield better results."
+         ))
+         accumulation_steps = int(st.sidebar.number_input(
+             "Gradient Accumulation Steps",
+             value=1,
+             min_value=1,
+             max_value=10,
+             help="Number of gradient accumulation steps used during training. Larger values may "
+                  "yield better results. Cannot be larger than half the dataset size."
+         ))
+         st.sidebar.markdown(
+             """
+ ### Template
+
+ Templates define how task-specific inputs are combined with trigger tokens to create
+ the prompt. They should contain the following placeholders:
+ - `{sentence}`: Placeholders for the task-specific input fields contain the field name
+   between curly brackets. For manually entered data the field name is `{sentence}`. For
+   uploaded CSVs, field names should correspond to columns in the CSV.
+ - `[T]`: Placeholder for a trigger token. These are learned from the training data.
+ - `[P]`: Placeholder for where to insert the [MASK] token that the model will predict
+   on.
+
+ Templates can also include manually written text (such as the
+ period in the default example below).
+ """
+         )
+         template = st.sidebar.text_input("Template", "{sentence} [T] [T] [T] [P].")
+         return cls(
+             template=template,
+             model_name=model_name,
+             iters=iters,
+             num_cand=num_cand,
+             accumulation_steps=accumulation_steps,
+         )
+
+
+ # TODO(rloganiv): This probably could use a better name...
+ @dataclass
+ class GlobalData:
+     device: torch.device
+     config: transformers.PretrainedConfig
+     model: transformers.PreTrainedModel
+     tokenizer: transformers.PreTrainedTokenizer
+     embeddings: torch.nn.Module
+     embedding_gradient: ct.GradientStorage
+     predictor: ct.PredictWrapper
+
+     @classmethod
+     @st.cache(allow_output_mutation=True)
+     def from_pretrained(cls, model_name):
+         logger.info(f'Loading pretrained model: {model_name}')
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         if torch.cuda.is_available():
+             st.write('CUDA is available')
+         else:
+             st.write('CUDA not available')
+         config, model, tokenizer = ct.load_pretrained(model_name)
+         model.to(device)
+         embeddings = ct.get_embeddings(model, config)
+         embedding_gradient = ct.GradientStorage(embeddings)
+         predictor = ct.PredictWrapper(model)
+         return cls(
+             device,
+             config,
+             model,
+             tokenizer,
+             embeddings,
+             embedding_gradient,
+             predictor
+         )
+
+
+ @dataclass
+ class Dataset:
+     train: List[int]
+     label_map: Dict[str, str]
+
+
+ def load_trigger_dataset(dataset, templatizer):
+     instances = []
+     for x in dataset:
+         instances.append(templatizer(x))
+     return instances
+
+
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True, hash_funcs={CacheTest: lambda o: 0})
+ def run_autoprompt(args, dataset, cache_test):
+     if cache_test.is_test:
+         raise CacheMiss()
+
+     ct.set_seed(args.seed)
+     global_data = GlobalData.from_pretrained(args.model_name)
+
+     templatizer = utils.TriggerTemplatizer(
+         args.template,
+         global_data.config,
+         global_data.tokenizer,
+         label_field=args.label_field,
+         label_map=dataset.label_map,
+         tokenize_labels=args.tokenize_labels,
+         add_special_tokens=True,
+     )
+     evaluation_fn = ct.AccuracyFn(global_data.tokenizer, dataset.label_map, global_data.device,
+                                   tokenize_labels=args.tokenize_labels)
+
+     # Do not allow for initial trigger specification.
+     trigger_ids = [global_data.tokenizer.mask_token_id] * templatizer.num_trigger_tokens
+     trigger_ids = torch.tensor(trigger_ids, device=global_data.device).unsqueeze(0)
+     best_trigger_ids = trigger_ids.clone()
+
+     # Load datasets
+     logger.info('Loading datasets')
+     collator = utils.Collator(pad_token_id=global_data.tokenizer.pad_token_id)
+     try:
+         train_dataset = load_trigger_dataset(dataset.train, templatizer)
+     except KeyError as e:
+         raise RuntimeError(
+             'A field in your template is not present in the uploaded dataset. '
+             f'Check that there is a column with the name: {e}'
+         )
+
+     train_loader = torch.utils.data.DataLoader(
+         train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
+
+     progress = st.progress(0.0)
+     trigger_placeholder = st.empty()
+     best_dev_metric = -float('inf')
+     for i in range(args.iters):
+         logger.info(f'Iteration: {i}')
+         progress.progress(float(i) / args.iters)
+
+         current_trigger = ','.join(global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0)))
+         trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')
+
+         global_data.model.zero_grad()
+         train_iter = iter(train_loader)
+         averaged_grad = None
+
+         # Compute gradient of loss
+         for step in range(args.accumulation_steps):
+             try:
+                 model_inputs, labels = next(train_iter)
+             except StopIteration:
+                 logger.warning(
+                     'Insufficient data for number of accumulation steps. '
+                     'Effective batch size will be smaller than specified.'
+                 )
+                 break
+             model_inputs = {k: v.to(global_data.device) for k, v in model_inputs.items()}
+             labels = labels.to(global_data.device)
+             predict_logits = global_data.predictor(model_inputs, trigger_ids)
+             loss = ct.get_loss(predict_logits, labels).mean()
+             loss.backward()
+
+             grad = global_data.embedding_gradient.get()
+             bsz, _, emb_dim = grad.size()
+             selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
+             grad = torch.masked_select(grad, selection_mask)
+             grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)
+
+             if averaged_grad is None:
+                 averaged_grad = grad.sum(dim=0) / args.accumulation_steps
+             else:
+                 averaged_grad += grad.sum(dim=0) / args.accumulation_steps
+
+         logger.info('Evaluating Candidates')
+         pbar = tqdm(range(args.accumulation_steps))
+         train_iter = iter(train_loader)
+
+         token_to_flip = i % templatizer.num_trigger_tokens
+         candidates = ct.hotflip_attack(averaged_grad[token_to_flip],
+                                        global_data.embeddings.weight,
+                                        increase_loss=False,
+                                        num_candidates=args.num_cand)
+         current_score = 0
+         candidate_scores = torch.zeros(args.num_cand, device=global_data.device)
+         denom = 0
+         for step in pbar:
+             try:
+                 model_inputs, labels = next(train_iter)
+             except StopIteration:
+                 logger.warning(
+                     'Insufficient data for number of accumulation steps. '
+                     'Effective batch size will be smaller than specified.'
+                 )
+                 break
+             model_inputs = {k: v.to(global_data.device) for k, v in model_inputs.items()}
+             labels = labels.to(global_data.device)
+             with torch.no_grad():
+                 predict_logits = global_data.predictor(model_inputs, trigger_ids)
+                 eval_metric = evaluation_fn(predict_logits, labels)
+
+             # Update current score
+             current_score += eval_metric.sum()
+             denom += labels.size(0)
+
+             # NOTE: Instead of iterating over tokens to flip we randomly change just one each
+             # time so the gradients don't get stale.
+             for idx, candidate in enumerate(candidates):
+
+                 # if candidate.item() in filter_candidates:
+                 #     candidate_scores[idx] = -1e32
+                 #     continue
+
+                 temp_trigger = trigger_ids.clone()
+                 temp_trigger[:, token_to_flip] = candidate
+                 with torch.no_grad():
+                     predict_logits = global_data.predictor(model_inputs, temp_trigger)
+                     eval_metric = evaluation_fn(predict_logits, labels)
+
+                 candidate_scores[idx] += eval_metric.sum()
+
+         if (candidate_scores >= current_score).any():
+             logger.info('Better trigger detected.')
+             best_candidate_score = candidate_scores.max()
+             best_candidate_idx = candidate_scores.argmax()
+             trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
+             logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')
+
+         # Skip eval
+         best_trigger_ids = trigger_ids.clone()
+
+     progress.progress(1.0)
+     current_trigger = ','.join(global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0)))
+     trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')
+
+     best_trigger_tokens = global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))
+
+     train_output = predict_test(map(lambda x: x['sentence'], dataset.train), dataset.label_map,
+                                 templatizer, best_trigger_ids, global_data.tokenizer,
+                                 global_data.predictor, args)
+
+     # Streamlit does not like accessing widgets across functions, which is
+     # problematic for this "live updating" widget which we want to still
+     # display even if the train output is cached. To get around this, we're
+     # going to delete the widget and replace it with a very similar looking
+     # widget outside the function...no one will ever notice ;)
+     trigger_placeholder.empty()
+
+     return (
+         best_trigger_tokens,
+         current_score / denom,
+         dataset.label_map,
+         templatizer,
+         best_trigger_ids,
+         global_data.tokenizer,
+         global_data.predictor,
+         args,
+         train_output
+     )
+
+
+ def predict_test(sentences, label_map, templatizer, best_trigger_ids, tokenizer, predictor, args):
+     # Evaluate clean
+     output = {'sentences': []}
+     any_label = None
+     for label in label_map.values():
+         output[label] = []
+         any_label = label
+     output['prompt'] = []
+     for sentence in sentences:
+         model_inputs, _ = templatizer({'sentence': sentence, 'label': any_label})
+         model_inputs = {k: v.to(best_trigger_ids.device) for k, v in model_inputs.items()}
+
+         prompt_ids = ct.replace_trigger_tokens(
+             model_inputs, best_trigger_ids, model_inputs['trigger_mask'])
+
+         prompt = ' '.join(tokenizer.convert_ids_to_tokens(prompt_ids['input_ids'][0]))
+         output['prompt'].append(prompt)
+
+         predict_logits = predictor(model_inputs, best_trigger_ids)
+         output['sentences'].append(sentence)
+         for label in label_map.values():
+             label_id = utils.encode_label(tokenizer=tokenizer, label=label, tokenize=args.tokenize_labels)
+             label_id = label_id.to(best_trigger_ids.device)
+             label_loss = ct.get_loss(predict_logits, label_id)
+             # st.write(sentence, label, label_loss)
+             output[label].append(label_loss.item())
+     return output
+
+
+ def manual_dataset(use_defaults):
+
+     num_train_instances = st.slider("Number of Train Instances", 4, 32, 8)
+     any_empty = False
+     dataset = []
+     data_col, label_col = st.beta_columns([3, 1])
+     for i in range(num_train_instances):
+         default_data = DEFAULT_TRAIN[i]['sentence'] if use_defaults else ''
+         default_label = DEFAULT_TRAIN[i]['label'] if use_defaults else ''
+         with data_col:
+             data = st.text_input("Train Instance " + str(i + 1), default_data)
+         with label_col:
+             label = st.text_input("Train Label " + str(i + 1), default_label, max_chars=20)
+         if data == "" or label == "":
+             any_empty = True
+         dataset.append({'sentence': data, 'label': label})
+
+     label_set = list(set(map(lambda x: x['label'], dataset)))
+     label_idx = {x: i for i, x in enumerate(label_set)}
+     label_map = dict(map(lambda x: (x, x), label_set))
+
+     if any_empty:
+         st.warning('Waiting for data to be added')
+         st.stop()
+
+     if len(label_set) < 2:
+         st.warning('Not enough labels')
+         st.stop()
+
+     return Dataset(
+         train=dataset,
+         label_map=label_map
+     )
+
+
+ def csv_dataset():
+     st.markdown("""
+ Please upload your training csv file.
+
+ Format restrictions:
+ - The file is required to have a header.
+ - The column name of the output field should be `label`.
+ - Each file should contain no more than 64 rows.
+ """)
+     train_csv = st.file_uploader('Train', accept_multiple_files=False)
+
+     if train_csv is None:
+         st.stop()
+
+     with io.StringIO(train_csv.getvalue().decode('utf-8')) as f:
+         reader = csv.DictReader(f)
+         train_dataset = list(reader)
+         if len(train_dataset) > 64:
+             raise ValueError('Train dataset is too large. Please limit the number '
+                              'of examples to 64 or less.')
+
+     labels = set(x['label'] for x in train_dataset)
+     label_map = {x: x for x in labels}
+
+     return Dataset(
+         train=train_dataset,
+         label_map=label_map
+     )
+
+
+ def run():
+     css_hack()
+     st.title('AutoPrompt Demo')
+     st.markdown('''
+ For many years, the predominant approach for training machine learning
+ models to solve NLP tasks has been to use supervised training data to
+ estimate model parameters using maximum likelihood estimation or some
+ similar paradigm. Whether fitting a logistic regression model over a
+ bag-of-words, an LSTM over a sequence of GloVe embeddings, or finetuning a
+ language model such as ELMo or BERT, the approach is essentially the same.
+ However, as language models have become more and more capable of accurately
+ generating plausible text, a new possibility for solving classification
+ tasks has emerged...
+
+ ## Prompting
+
+ Prompting is the method of converting classification tasks into
+ *fill-in-the-blanks* problems that can be solved by a language model **without
+ modifying the model's internals**. For example, to perform sentiment analysis,
+ we may take the sentence we wish to classify, append the text "Overall, this
+ movie was ____.", and feed it into a language model like so:
+ ''')
+     # st.image('assets/bert-mouth.png', use_column_width=True)
+     st.markdown('''
+ By measuring whether the language model assigns a higher probability to
+ words that are associated with a **positive** sentiment ("good", "great",
+ and "fantastic") vs. words that are associated with a **negative**
+ sentiment ("bad", "terrible", or "awful"), we can infer the
+ predicted label for the given input. So in this example, because the word "good"
+ has a higher probability than "bad", the predicted label is **positive**.
+
+ ## AutoPrompt
+
+ One issue that arises when using prompts is that it is not usually clear
+ how to best pose a task as a fill-in-the-blanks problem in a way that gets
+ the most performance from the language model. Even for a simple problem
+ like sentiment analysis, we don't know whether it is better to ask whether
+ a movie is good/bad, or whether you feel great/terrible about it, and for
+ more abstract problems like natural language inference it is difficult to
+ even know where to start.
+
+ To cure this writer's block we introduce **AutoPrompt**, a data-driven
+ approach for automatic prompt construction. The basic idea is
+ straightforward: instead of writing a prompt, a user need only write a
+ **template** that specifies where the *task inputs* go, along with placeholders
+ for a number of *trigger tokens* that will automatically be learned by the
+ model, and the *predict token* that the model will fill in:
+ ''')
+     # st.image('assets/template.png', use_column_width=True)
+     st.markdown(
+         '''
+ In each iteration of the search process:
+ 1. The template is instantiated using a batch of training inputs.
+ 2. The loss of the model on each input is measured and used to identify a
+    number of candidate replacements for the current trigger tokens.
+ 3. The performance of each candidate is measured on another batch of
+    training data, and the best performing candidate is used in the next
+    iteration.
+
+ ### Demo
+
+ To give a better sense of how AutoPrompt works, we have provided a simple
+ interactive demo. You can generate a prompt using the training data we have
+ pre-populated for you, or alternatively write your own training/evaluation
+ instances or upload them using a csv below. In addition, you can vary
+ some of the training parameters, as well as the template, using the sidebar
+ on the left.
+ '''
+     )
+     args = Args.from_streamlit()
+     dataset_mode = st.radio('How would you like to input your training data?',
+                             options=['Example Data', 'Manual Input', 'From CSV'])
+
+     if dataset_mode == 'Example Data':
+         dataset = manual_dataset(use_defaults=True)
+     elif dataset_mode == 'Manual Input':
+         dataset = manual_dataset(use_defaults=False)
+     else:
+         dataset = csv_dataset()
+
+     button = st.empty()
+     clicked = button.button('Train')
+
+     if clicked:
+         trigger_tokens, eval_metric, label_map, templatizer, best_trigger_ids, tokenizer, predictor, args, train_output = run_autoprompt(args, dataset, cache_test=CacheTest(False))
+     else:
+         try:
+             trigger_tokens, eval_metric, label_map, templatizer, best_trigger_ids, tokenizer, predictor, args, train_output = run_autoprompt(args, dataset, cache_test=CacheTest(True))
+         except CacheMiss:
+             st.stop()
+         else:
+             button.empty()
+
+     st.markdown(f'**Final trigger**: {", ".join(trigger_tokens)}')
+     st.dataframe(pd.DataFrame(train_output).style.highlight_min(axis=1, color='#94666b'))
+     logger.debug('Dev metric')
+     st.write('Accuracy: ' + str(round(eval_metric.item() * 100, 1)))
+     st.write("""
+ Et voila, you've now effectively finetuned a classifier using just a few
+ kilobytes of parameters (the tokens in the prompt). If you like you can
+ write down your "model" on the back of a napkin and take it with you.
+
+ ### Try it out yourself!
+ """)
+     sentence = st.text_input("Sentence", 'Enter a test input here')
+     pred_output = predict_test([sentence], label_map, templatizer, best_trigger_ids, tokenizer, predictor, args)
+     st.dataframe(pd.DataFrame(pred_output).style.highlight_min(axis=1, color='#94666b'))
+
+     st.markdown('''
+ ## Where can I learn more?
+
+ If you are interested in learning more about AutoPrompt we recommend
+ [reading our paper](https://arxiv.org/abs/2010.15980) and [checking out our
+ code](https://github.com/ucinlp/autoprompt), or if you'd like you can also
+ watch our presentation at EMNLP 2020:
+ ''')
+     st.components.v1.iframe(
+         src="https://www.youtube.com/embed/IBMT_oOCBbc",
+         height=400,
+     )
+     st.markdown('Thanks!')
+
+
+ if __name__ == '__main__':
+     logging.basicConfig(level=logging.INFO,
+                         stream=sys.stdout)
+     run()
+
app/.streamlit/config.toml ADDED
@@ -0,0 +1,10 @@
+ [server]
+ enableCORS = false
+ enableXsrfProtection = false
+
+ [theme]
+ primaryColor="#96666b"
+ backgroundColor="#28282d"
+ secondaryBackgroundColor="#333333"
+ textColor="#f3f3f3"
+ font="monospace"
assets/icon.png ADDED
assets/sst2_train.jsonl ADDED
@@ -0,0 +1,32 @@
+ {"label": "terrible", "idx": "61123", "sentence": "in its yearning for the days "}
+ {"label": "great", "idx": "23159", "sentence": "of riveting set pieces "}
+ {"label": "great", "idx": "21277", "sentence": "all-star reunions "}
+ {"label": "terrible", "idx": "27987", "sentence": "( swimfan ) falls victim to sloppy plotting , an insultingly unbelievable final act and a villainess who is too crazy to be interesting . "}
+ {"label": "great", "idx": "23240", "sentence": "the leads ) are such a companionable couple "}
+ {"label": "great", "idx": "25838", "sentence": "astonishingly skillful and moving "}
+ {"label": "terrible", "idx": "35653", "sentence": "has been sacrificed for the sake of spectacle "}
+ {"label": "terrible", "idx": "49778", "sentence": "hard to imagine acting that could be any flatter "}
+ {"label": "great", "idx": "2403", "sentence": "the better video-game-based flicks , "}
+ {"label": "great", "idx": "135", "sentence": "so many of the challenges it poses for itself that one can forgive the film its flaws "}
+ {"label": "great", "idx": "42426", "sentence": "while somewhat less than it might have been , the film is a good one "}
+ {"label": "great", "idx": "6863", "sentence": "has a dashing and resourceful hero ; a lisping , reptilian villain ; big fights ; big hair ; lavish period scenery ; and a story "}
+ {"label": "terrible", "idx": "42330", "sentence": "in the end , the weight of water comes to resemble the kind of soft-core twaddle you 'd expect to see on showtime 's ` red shoe diaries . ' "}
+ {"label": "terrible", "idx": "57545", "sentence": "stuck in heaven because he 's afraid of his best-known creation ? "}
+ {"label": "terrible", "idx": "23530", "sentence": "can be as tiresome as 9 seconds of jesse helms ' anti- castro "}
+ {"label": "great", "idx": "54745", "sentence": "tackles the difficult subject of grief and loss with such life-embracing spirit that the theme does n't drag an audience down "}
+ {"label": "terrible", "idx": "30797", "sentence": "violence "}
+ {"label": "great", "idx": "30169", "sentence": "feminine energy , a tribute to the power of women to heal "}
+ {"label": "terrible", "idx": "42869", "sentence": "flat as a spoof "}
+ {"label": "terrible", "idx": "33313", "sentence": "( somebody suggested the stills might make a nice coffee table book ) "}
+ {"label": "great", "idx": "10766", "sentence": "brosnan 's finest non-bondish performance "}
+ {"label": "terrible", "idx": "15631", "sentence": "all seemed wasted like deniro 's once promising career and the once grand long beach boardwalk . "}
+ {"label": "great", "idx": "3401", "sentence": "a lot smarter and more unnerving than the sequels "}
+ {"label": "terrible", "idx": "29775", "sentence": "it may not be particularly innovative "}
+ {"label": "great", "idx": "39776", "sentence": "it 's packed with adventure and a worthwhile environmental message , so it 's great for the kids . "}
+ {"label": "terrible", "idx": "64246", "sentence": "mind ugly "}
+ {"label": "terrible", "idx": "57404", "sentence": "meandering , norton has to recite bland police procedural details , fiennes wanders around in an attempt to seem weird and distanced , hopkins looks like a drag queen "}
+ {"label": "terrible", "idx": "3766", "sentence": "pathetic idea "}
+ {"label": "great", "idx": "63925", "sentence": "has done his homework and "}
+ {"label": "great", "idx": "50687", "sentence": "a very compelling , sensitive , intelligent and almost cohesive piece "}
+ {"label": "terrible", "idx": "54945", "sentence": "turn and devolves "}
+ {"label": "great", "idx": "35874", "sentence": "about this silly , outrageous , ingenious thriller "}
autoprompt/__init__.py ADDED
File without changes
autoprompt/create_trigger.py ADDED
@@ -0,0 +1,523 @@
1
+ import time
2
+ import argparse
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import DataLoader
12
+ import transformers
13
+ from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer
14
+ from tqdm import tqdm
15
+
16
+ import autoprompt.utils as utils
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class GradientStorage:
23
+ """
24
+ This object stores the intermediate gradients of the output a the given PyTorch module, which
25
+ otherwise might not be retained.
26
+ """
27
+ def __init__(self, module):
28
+ self._stored_gradient = None
29
+ module.register_backward_hook(self.hook)
30
+
31
+ def hook(self, module, grad_in, grad_out):
32
+ self._stored_gradient = grad_out[0]
33
+
34
+ def get(self):
35
+ return self._stored_gradient
36
+
37
+
38
+ class PredictWrapper:
39
+ """
40
+ PyTorch transformers model wrapper. Handles necc. preprocessing of inputs for triggers
41
+ experiments.
42
+ """
43
+ def __init__(self, model):
44
+ self._model = model
45
+
46
+ def __call__(self, model_inputs, trigger_ids):
47
+ # Copy dict so pop operations don't have unwanted side-effects
48
+ model_inputs = model_inputs.copy()
49
+ trigger_mask = model_inputs.pop('trigger_mask')
50
+ predict_mask = model_inputs.pop('predict_mask')
51
+ model_inputs = replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
52
+ logits, *_ = self._model(**model_inputs)
53
+ predict_logits = logits.masked_select(predict_mask.unsqueeze(-1)).view(logits.size(0), -1)
54
+ return predict_logits
55
+
56
+
57
+ class AccuracyFn:
58
+ """
59
+ Computing the accuracy when a label is mapped to multiple tokens is difficult in the current
60
+ framework, since the data generator only gives us the token ids. To get around this we
61
+ compare the target logp to the logp of all labels. If target logp is greater than all (but)
62
+ one of the label logps we know we are accurate.
63
+ """
64
+ def __init__(self, tokenizer, label_map, device, tokenize_labels=False):
65
+ self._all_label_ids = []
66
+ self._pred_to_label = []
67
+ logger.info(label_map)
68
+ for label, label_tokens in label_map.items():
69
+ self._all_label_ids.append(utils.encode_label(tokenizer, label_tokens, tokenize_labels).to(device))
70
+ self._pred_to_label.append(label)
71
+ logger.info(self._all_label_ids)
72
+
73
+ def __call__(self, predict_logits, gold_label_ids):
74
+ # Get total log-probability for the true label
75
+ gold_logp = get_loss(predict_logits, gold_label_ids)
76
+
77
+ # Get total log-probability for all labels
78
+ bsz = predict_logits.size(0)
79
+ all_label_logp = []
80
+ for label_ids in self._all_label_ids:
81
+ label_logp = get_loss(predict_logits, label_ids.repeat(bsz, 1))
82
+ all_label_logp.append(label_logp)
83
+ all_label_logp = torch.stack(all_label_logp, dim=-1)
84
+ _, predictions = all_label_logp.max(dim=-1)
85
+ predictions = [self._pred_to_label[x] for x in predictions.tolist()]
86
+
87
+ # Add up the number of entries where loss is greater than or equal to gold_logp.
88
+ ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1)
89
+ correct = ge_count.le(1) # less than in case of num. prec. issues
90
+
91
+ return correct.float()
92
+
93
+ # TODO: @rloganiv - This is hacky. Replace with something sensible.
94
+ def predict(self, predict_logits):
95
+ bsz = predict_logits.size(0)
96
+ all_label_logp = []
97
+ for label_ids in self._all_label_ids:
98
+ label_logp = get_loss(predict_logits, label_ids.repeat(bsz, 1))
99
+ all_label_logp.append(label_logp)
100
+ all_label_logp = torch.stack(all_label_logp, dim=-1)
101
+ _, predictions = all_label_logp.max(dim=-1)
102
+ predictions = [self._pred_to_label[x] for x in predictions.tolist()]
103
+ return predictions
104
+
105
+
106
+ def load_pretrained(model_name):
107
+ """
108
+ Loads pretrained HuggingFace config/model/tokenizer, as well as performs required
109
+ initialization steps to facilitate working with triggers.
110
+ """
111
+ config = AutoConfig.from_pretrained(model_name)
112
+ model = AutoModelWithLMHead.from_pretrained(model_name)
113
+ model.eval()
114
+ tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
115
+ utils.add_task_specific_tokens(tokenizer)
116
+ return config, model, tokenizer
117
+
118
+
119
+ def set_seed(seed: int):
120
+ """Sets the relevant random seeds."""
121
+ random.seed(seed)
122
+ np.random.seed(seed)
123
+ torch.random.manual_seed(seed)
124
+ torch.cuda.manual_seed(seed)
125
+
126
+
127
+ def get_embeddings(model, config):
128
+ """Returns the wordpiece embedding module."""
129
+ base_model = getattr(model, config.model_type)
130
+ embeddings = base_model.embeddings.word_embeddings
131
+ return embeddings
132
+
133
+
134
+ def hotflip_attack(averaged_grad,
135
+ embedding_matrix,
136
+ increase_loss=False,
137
+ num_candidates=1,
138
+ filter=None):
139
+ """Returns the top candidate replacements."""
140
+ with torch.no_grad():
141
+ gradient_dot_embedding_matrix = torch.matmul(
142
+ embedding_matrix,
143
+ averaged_grad
144
+ )
145
+ if filter is not None:
146
+ gradient_dot_embedding_matrix -= filter
147
+ if not increase_loss:
148
+ gradient_dot_embedding_matrix *= -1
149
+ _, top_k_ids = gradient_dot_embedding_matrix.topk(num_candidates)
150
+
151
+ return top_k_ids
152
+
153
+
154
+ def replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask):
155
+ """Replaces the trigger tokens in input_ids."""
156
+ out = model_inputs.copy()
157
+ input_ids = model_inputs['input_ids']
158
+ trigger_ids = trigger_ids.repeat(trigger_mask.size(0), 1)
159
+ try:
160
+ filled = input_ids.masked_scatter(trigger_mask, trigger_ids)
161
+ except RuntimeError:
162
+ filled = input_ids
163
+ out['input_ids'] = filled
164
+ return out
165
+
166
+
167
+ def get_loss(predict_logits, label_ids):
168
+ predict_logp = F.log_softmax(predict_logits, dim=-1)
169
+ target_logp = predict_logp.gather(-1, label_ids)
170
+ target_logp = target_logp - 1e32 * label_ids.eq(0) # Apply mask
171
+ target_logp = torch.logsumexp(target_logp, dim=-1)
172
+ return -target_logp
173
+
174
+
175
+ def isupper(idx, tokenizer):
176
+ """
177
+ Determines whether a token (e.g., word piece) begins with a capital letter.
178
+ """
179
+ _isupper = False
180
+ # We only want to check tokens that begin words. Since byte-pair encoding
181
+ # captures a prefix space, we need to check that the decoded token begins
182
+ # with a space, and has a capitalized second character.
183
+ if isinstance(tokenizer, transformers.GPT2Tokenizer):
184
+ decoded = tokenizer.decode([idx])
185
+ if decoded[0] == ' ' and decoded[1].isupper():
186
+ _isupper = True
187
+ # For all other tokenization schemes, we can just check the first character
188
+ # is capitalized.
189
+ elif tokenizer.decode([idx])[0].isupper():
190
+ _isupper = True
191
+ return _isupper
192
+
193
+
194
+ def run_model(args):
195
+
196
+ set_seed(args.seed)
197
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
198
+
199
+ logger.info('Loading model, tokenizer, etc.')
200
+ config, model, tokenizer = load_pretrained(args.model_name)
201
+ model.to(device)
202
+ embeddings = get_embeddings(model, config)
203
+ embedding_gradient = GradientStorage(embeddings)
204
+ predictor = PredictWrapper(model)
205
+
206
+ if args.label_map is not None:
207
+ label_map = json.loads(args.label_map)
208
+ logger.info(f"Label map: {label_map}")
209
+ else:
210
+ label_map = None
211
+ logger.info('No label map')
212
+
213
+ templatizer = utils.TriggerTemplatizer(
214
+ args.template,
215
+ config,
216
+ tokenizer,
217
+ label_map=label_map,
218
+ label_field=args.label_field,
219
+ tokenize_labels=args.tokenize_labels,
220
+ add_special_tokens=False,
221
+ use_ctx=args.use_ctx
222
+ )
223
+
224
+ # Obtain the initial trigger tokens and label mapping
225
+ if args.initial_trigger:
226
+ trigger_ids = tokenizer.convert_tokens_to_ids(args.initial_trigger)
227
+ logger.debug(f'Initial trigger: {args.initial_trigger}')
228
+ logger.debug(f'Trigger ids: {trigger_ids}')
229
+ assert len(trigger_ids) == templatizer.num_trigger_tokens
230
+ else:
231
+ trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
232
+ trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)
233
+ best_trigger_ids = trigger_ids.clone()
234
+
235
+ # NOTE: Accuracy can only be computed if a fixed pool of labels is given, which currently
236
+ # requires the label map to be specified. Since producing a label map may be cumbersome (e.g.,
237
+ # for link prediction tasks), we just use (negative) loss as the evaluation metric in these cases.
238
+ if label_map:
239
+ evaluation_fn = AccuracyFn(tokenizer, label_map, device)
240
+ else:
241
+ evaluation_fn = lambda x, y: -get_loss(x, y)
242
+
243
+ logger.info('Loading datasets')
244
+ collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
245
+
246
+ if args.perturbed:
247
+ train_dataset = utils.load_augmented_trigger_dataset(args.train, templatizer, limit=args.limit)
248
+ else:
249
+ train_dataset = utils.load_trigger_dataset(args.train, templatizer, use_ctx=args.use_ctx, limit=args.limit)
250
+ train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
251
+
252
+ if args.perturbed:
253
+ dev_dataset = utils.load_augmented_trigger_dataset(args.dev, templatizer)
254
+ else:
255
+ dev_dataset = utils.load_trigger_dataset(args.dev, templatizer, use_ctx=args.use_ctx)
256
+ dev_loader = DataLoader(dev_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)
257
+
258
+ # To "filter" unwanted trigger tokens, we subtract a huge number from their logits.
259
+ filter = torch.zeros(tokenizer.vocab_size, dtype=torch.float32, device=device)
260
+ if args.filter:
261
+ logger.info('Filtering label tokens.')
262
+ if label_map:
263
+ for label_tokens in label_map.values():
264
+ label_ids = utils.encode_label(tokenizer, label_tokens).unsqueeze(0)
265
+ filter[label_ids] = -1e32
266
+ else:
267
+ for _, label_ids in train_dataset:
268
+ filter[label_ids] = -1e32
269
+ logger.info('Filtering special tokens and capitalized words.')
270
+ for word, idx in tokenizer.get_vocab().items():
271
+ if len(word) == 1 or idx >= tokenizer.vocab_size:
272
+ continue
273
+ # Filter special tokens.
274
+ if idx in tokenizer.all_special_ids:
275
+ logger.debug('Filtered: %s', word)
276
+ filter[idx] = -1e32
277
+ # Filter capitalized words (lazy way to remove proper nouns).
278
+ if isupper(idx, tokenizer):
279
+ logger.debug('Filtered: %s', word)
280
+ filter[idx] = -1e32
281
+
282
+ logger.info('Evaluating')
283
+ numerator = 0
284
+ denominator = 0
285
+ for model_inputs, labels in tqdm(dev_loader):
286
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
287
+ labels = labels.to(device)
288
+ with torch.no_grad():
289
+ predict_logits = predictor(model_inputs, trigger_ids)
290
+ numerator += evaluation_fn(predict_logits, labels).sum().item()
291
+ denominator += labels.size(0)
292
+ dev_metric = numerator / (denominator + 1e-13)
293
+ logger.info(f'Dev metric: {dev_metric}')
294
+
295
+ best_dev_metric = -float('inf')
296
+ # Measure elapsed time of trigger search
297
+ start = time.time()
298
+
299
+ for i in range(args.iters):
300
+
301
+ logger.info(f'Iteration: {i}')
302
+
303
+ logger.info('Accumulating Gradient')
304
+ model.zero_grad()
305
+
306
+ pbar = tqdm(range(args.accumulation_steps))
307
+ train_iter = iter(train_loader)
308
+ averaged_grad = None
309
+
310
+ # Accumulate
311
+ for step in pbar:
312
+
313
+ # Shuttle inputs to GPU
314
+ try:
315
+ model_inputs, labels = next(train_iter)
316
+ except:
317
+ logger.warning(
318
+ 'Insufficient data for number of accumulation steps. '
319
+ 'Effective batch size will be smaller than specified.'
320
+ )
321
+ break
322
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
323
+ labels = labels.to(device)
324
+ predict_logits = predictor(model_inputs, trigger_ids)
325
+ loss = get_loss(predict_logits, labels).mean()
326
+ loss.backward()
327
+
328
+ grad = embedding_gradient.get()
329
+ bsz, _, emb_dim = grad.size()
330
+ selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
331
+ grad = torch.masked_select(grad, selection_mask)
332
+ grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)
333
+
334
+ if averaged_grad is None:
335
+ averaged_grad = grad.sum(dim=0) / args.accumulation_steps
336
+ else:
337
+ averaged_grad += grad.sum(dim=0) / args.accumulation_steps
338
+
339
+ logger.info('Evaluating Candidates')
340
+ pbar = tqdm(range(args.accumulation_steps))
341
+ train_iter = iter(train_loader)
342
+
343
+ token_to_flip = random.randrange(templatizer.num_trigger_tokens)
344
+ candidates = hotflip_attack(averaged_grad[token_to_flip],
345
+ embeddings.weight,
346
+ increase_loss=False,
347
+ num_candidates=args.num_cand,
348
+ filter=filter)
349
+
350
+ current_score = 0
351
+ candidate_scores = torch.zeros(args.num_cand, device=device)
352
+ denom = 0
353
+ for step in pbar:
354
+
355
+ try:
356
+ model_inputs, labels = next(train_iter)
357
+ except:
358
+ logger.warning(
359
+ 'Insufficient data for number of accumulation steps. '
360
+ 'Effective batch size will be smaller than specified.'
361
+ )
362
+ break
363
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
364
+ labels = labels.to(device)
365
+ with torch.no_grad():
366
+ predict_logits = predictor(model_inputs, trigger_ids)
367
+ eval_metric = evaluation_fn(predict_logits, labels)
368
+
369
+ # Update current score
370
+ current_score += eval_metric.sum()
371
+ denom += labels.size(0)
372
+
373
+ # NOTE: Instead of iterating over tokens to flip we randomly change just one each
374
+ # time so the gradients don't get stale.
375
+ for i, candidate in enumerate(candidates):
376
+
377
+ # if candidate.item() in filter_candidates:
378
+ # candidate_scores[i] = -1e32
379
+ # continue
380
+
381
+ temp_trigger = trigger_ids.clone()
382
+ temp_trigger[:, token_to_flip] = candidate
383
+ with torch.no_grad():
384
+ predict_logits = predictor(model_inputs, temp_trigger)
385
+ eval_metric = evaluation_fn(predict_logits, labels)
386
+
387
+ candidate_scores[i] += eval_metric.sum()
388
+
389
+ # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
390
+ # there are still mask tokens in the trigger then set the current score
391
+ # to -inf.
392
+ if args.print_lama:
393
+ if trigger_ids.eq(tokenizer.mask_token_id).any():
394
+ current_score = float('-inf')
395
+
396
+ if (candidate_scores > current_score).any():
397
+ logger.info('Better trigger detected.')
398
+ best_candidate_score = candidate_scores.max()
399
+ best_candidate_idx = candidate_scores.argmax()
400
+ trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
401
+ logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')
402
+ else:
403
+ logger.info('No improvement detected. Skipping evaluation.')
404
+ continue
405
+
406
+ logger.info('Evaluating')
407
+ numerator = 0
408
+ denominator = 0
409
+ for model_inputs, labels in tqdm(dev_loader):
410
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
411
+ labels = labels.to(device)
412
+ with torch.no_grad():
413
+ predict_logits = predictor(model_inputs, trigger_ids)
414
+ numerator += evaluation_fn(predict_logits, labels).sum().item()
415
+ denominator += labels.size(0)
416
+ dev_metric = numerator / (denominator + 1e-13)
417
+
418
+ logger.info(f'Trigger tokens: {tokenizer.convert_ids_to_tokens(trigger_ids.squeeze(0))}')
419
+ logger.info(f'Dev metric: {dev_metric}')
420
+
421
+ # TODO: Something cleaner. LAMA templates can't have mask tokens, so if
422
+ # there are still mask tokens in the trigger then set the current score
423
+ # to -inf.
424
+ if args.print_lama:
425
+ if best_trigger_ids.eq(tokenizer.mask_token_id).any():
426
+ best_dev_metric = float('-inf')
427
+
428
+ if dev_metric > best_dev_metric:
429
+ logger.info('Best performance so far')
430
+ best_trigger_ids = trigger_ids.clone()
431
+ best_dev_metric = dev_metric
432
+
433
+ best_trigger_tokens = tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))
434
+ logger.info(f'Best tokens: {best_trigger_tokens}')
435
+ logger.info(f'Best dev metric: {best_dev_metric}')
436
+ if args.print_lama:
437
+ # Templatize with [X] and [Y]
438
+ if args.use_ctx:
439
+ model_inputs, label_ids = templatizer({
440
+ 'sub_label': '[X]',
441
+ 'obj_label': tokenizer.lama_y,
442
+ 'context': ''
443
+ })
444
+ else:
445
+ model_inputs, label_ids = templatizer({
446
+ 'sub_label': '[X]',
447
+ 'obj_label': tokenizer.lama_y,
448
+ })
449
+ lama_template = model_inputs['input_ids']
450
+ # Instantiate trigger tokens
451
+ lama_template.masked_scatter_(
452
+ mask=model_inputs['trigger_mask'],
453
+ source=best_trigger_ids.cpu())
454
+ # Instantiate label token
455
+ lama_template.masked_scatter_(
456
+ mask=model_inputs['predict_mask'],
457
+ source=label_ids)
458
+ # Print LAMA JSON template
459
+ relation = args.train.parent.stem
460
+
461
+ # The following block of code is a bit hacky but whatever, it gets the job done
462
+ if args.use_ctx:
463
+ template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[SEP] ', '').replace('</s> ', '').replace('[ X ]', '[X]')
464
+ else:
465
+ template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[ X ]', '[X]')
466
+
467
+ out = {
468
+ 'relation': args.train.parent.stem,
469
+ 'template': template
470
+ }
471
+ print(json.dumps(out))
472
+
473
+
474
+ if __name__ == '__main__':
475
+ parser = argparse.ArgumentParser()
476
+ parser.add_argument('--train', type=Path, required=True, help='Train data path')
477
+ parser.add_argument('--dev', type=Path, required=True, help='Dev data path')
478
+ parser.add_argument('--template', type=str, help='Template string')
479
+ parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map')
480
+
481
+ # LAMA-specific
482
+ parser.add_argument('--tokenize-labels', action='store_true',
483
+ help='If specified labels are split into word pieces.'
484
+ 'Needed for LAMA probe experiments.')
485
+ parser.add_argument('--filter', action='store_true',
486
+ help='If specified, filter out special tokens and gold objects. '
487
+ 'Furthermore, tokens starting with capital '
488
+ 'letters will not appear in triggers. Lazy '
489
+ 'approach for removing proper nouns.')
490
+ parser.add_argument('--print-lama', action='store_true',
491
+ help='Prints best trigger in LAMA format.')
492
+
493
+ parser.add_argument('--initial-trigger', nargs='+', type=str, default=None, help='Manual prompt')
494
+ parser.add_argument('--label-field', type=str, default='label',
495
+ help='Name of the label field')
496
+
497
+ parser.add_argument('--bsz', type=int, default=32, help='Batch size')
498
+ parser.add_argument('--eval-size', type=int, default=256, help='Eval size')
499
+ parser.add_argument('--iters', type=int, default=100,
500
+ help='Number of iterations to run trigger search algorithm')
501
+ parser.add_argument('--accumulation-steps', type=int, default=10)
502
+ parser.add_argument('--model-name', type=str, default='bert-base-cased',
503
+ help='Model name passed to HuggingFace AutoX classes.')
504
+ parser.add_argument('--seed', type=int, default=0)
505
+ parser.add_argument('--limit', type=int, default=None)
506
+ parser.add_argument('--use-ctx', action='store_true',
507
+ help='Use context sentences for relation extraction only')
508
+ parser.add_argument('--perturbed', action='store_true',
509
+ help='Perturbed sentence evaluation of relation extraction: replace each object in dataset with a random other object')
510
+ parser.add_argument('--patience', type=int, default=5)
511
+ parser.add_argument('--num-cand', type=int, default=10)
512
+ parser.add_argument('--sentence-size', type=int, default=50)
513
+
514
+ parser.add_argument('--debug', action='store_true')
515
+ args = parser.parse_args()
516
+
517
+ if args.debug:
518
+ level = logging.DEBUG
519
+ else:
520
+ level = logging.INFO
521
+ logging.basicConfig(level=level)
522
+
523
+ run_model(args)
autoprompt/finetune.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ Script for running fine-tuning on GLUE tasks.
3
+
4
+ Largely copied from:
5
+ https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py
6
+ """
7
+ import argparse
8
+ import logging
9
+ from pathlib import Path
10
+ import random
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import DataLoader
16
+ from torch.optim.lr_scheduler import LambdaLR
17
+ import transformers
18
+ from transformers import (
19
+ AdamW, AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
20
+ )
21
+ from tqdm import tqdm
22
+
23
+ import autoprompt.utils as utils
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ def set_seed(seed: int):
30
+ """Sets the relevant random seeds."""
31
+ random.seed(seed)
32
+ np.random.seed(seed)
33
+ torch.random.manual_seed(seed)
34
+ torch.cuda.manual_seed(seed)
35
+
36
+
37
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
38
+ """ Create a schedule with a learning rate that decreases linearly after
39
+ linearly increasing during a warmup period.
40
+
41
+ From:
42
+ https://github.com/uds-lsv/bert-stable-fine-tuning/blob/master/src/transformers/optimization.py
43
+ """
44
+
45
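+ # For example, with num_warmup_steps=100 and num_training_steps=1000, the multiplier
+ # below rises from 0 to 1 over the first 100 steps, then decays linearly to 0 at step 1000.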
+ def lr_lambda(current_step):
46
+ if current_step < num_warmup_steps:
47
+ return float(current_step) / float(max(1, num_warmup_steps))
48
+ return max(
49
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
50
+ )
51
+
52
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
53
+
54
+
55
+ def main(args):
56
+ set_seed(args.seed)
57
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58
+
59
+ config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
60
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
61
+ model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
62
+ model.to(device)
63
+
64
+ collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
65
+ train_dataset, label_map = utils.load_classification_dataset(
66
+ args.train,
67
+ tokenizer,
68
+ args.field_a,
69
+ args.field_b,
70
+ args.label_field,
71
+ limit=args.limit
72
+ )
73
+ train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
74
+ dev_dataset, _ = utils.load_classification_dataset(
75
+ args.dev,
76
+ tokenizer,
77
+ args.field_a,
78
+ args.field_b,
79
+ args.label_field,
80
+ label_map
81
+ )
82
+ dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
83
+ test_dataset, _ = utils.load_classification_dataset(
84
+ args.test,
85
+ tokenizer,
86
+ args.field_a,
87
+ args.field_b,
88
+ args.label_field,
89
+ label_map
90
+ )
91
+ test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
92
+
93
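+ # NOTE: Rather than toggling AdamW's bias-correction term directly, disabling
+ # --bias-correction zeroes both Adam betas, which also removes the momentum and
+ # variance estimates entirely.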
+ if args.bias_correction:
94
+ betas = (0.9, 0.999)
95
+ else:
96
+ betas = (0.0, 0.0)
97
+
98
+ optimizer = AdamW(
99
+ model.parameters(),
100
+ lr=args.lr,
101
+ weight_decay=1e-2,
102
+ betas=betas
103
+ )
104
+
105
+ # Use suggested learning rate scheduler
106
+ num_training_steps = len(train_dataset) * args.epochs // args.bsz
107
+ num_warmup_steps = num_training_steps // 10
108
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
109
+ num_training_steps)
110
+
111
+ if not args.ckpt_dir.exists():
112
+ logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
113
+ args.ckpt_dir.mkdir(parents=True)
114
+ elif not args.force_overwrite:
115
+ raise RuntimeError('Checkpoint directory already exists.')
116
+
117
+ try:
118
+ best_accuracy = 0
119
+ for epoch in range(args.epochs):
120
+ logger.info('Training...')
121
+ model.train()
122
+ avg_loss = utils.ExponentialMovingAverage()
123
+ pbar = tqdm(train_loader)
124
+ for model_inputs, labels in pbar:
125
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
126
+ labels = labels.to(device)
127
+ optimizer.zero_grad()
128
+ logits, *_ = model(**model_inputs)
129
+ loss = F.cross_entropy(logits, labels.squeeze(-1))
130
+ loss.backward()
131
+ optimizer.step()
132
+ scheduler.step()
133
+ avg_loss.update(loss.item())
134
+ pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}, '
135
+ f'lr: {optimizer.param_groups[0]["lr"]: .3e}')
136
+
137
+ logger.info('Evaluating...')
138
+ model.eval()
139
+ correct = 0
140
+ total = 0
141
+ with torch.no_grad():
142
+ for model_inputs, labels in dev_loader:
143
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
144
+ labels = labels.to(device)
145
+ logits, *_ = model(**model_inputs)
146
+ _, preds = logits.max(dim=-1)
147
+ correct += (preds == labels.squeeze(-1)).sum().item()
148
+ total += labels.size(0)
149
+ accuracy = correct / (total + 1e-13)
150
+ logger.info(f'Accuracy: {accuracy : 0.4f}')
151
+
152
+ if accuracy > best_accuracy:
153
+ logger.info('Best performance so far.')
154
+ model.save_pretrained(args.ckpt_dir)
155
+ tokenizer.save_pretrained(args.ckpt_dir)
156
+ best_accuracy = accuracy
157
+ except KeyboardInterrupt:
158
+ logger.info('Interrupted...')
159
+
160
+ logger.info('Testing...')
161
+ model.eval()
162
+ correct = 0
163
+ total = 0
164
+ with torch.no_grad():
165
+ for model_inputs, labels in test_loader:
166
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
167
+ labels = labels.to(device)
168
+ logits, *_ = model(**model_inputs)
169
+ _, preds = logits.max(dim=-1)
170
+ correct += (preds == labels.squeeze(-1)).sum().item()
171
+ total += labels.size(0)
172
+ accuracy = correct / (total + 1e-13)
173
+ logger.info(f'Accuracy: {accuracy : 0.4f}')
174
+
175
+
176
+ if __name__ == '__main__':
177
+ parser = argparse.ArgumentParser()
178
+ parser.add_argument('--model-name', type=str)
179
+ parser.add_argument('--train', type=Path)
180
+ parser.add_argument('--dev', type=Path)
181
+ parser.add_argument('--test', type=Path)
182
+ parser.add_argument('--field-a', type=str)
183
+ parser.add_argument('--field-b', type=str, default=None)
184
+ parser.add_argument('--label-field', type=str, default='label')
185
+ parser.add_argument('--ckpt-dir', type=Path, default=Path('ckpt/'))
186
+ parser.add_argument('--num-labels', type=int, default=2)
187
+ parser.add_argument('--bsz', type=int, default=32)
188
+ parser.add_argument('--epochs', type=int, default=3)
189
+ parser.add_argument('--lr', type=float, default=2e-5)
190
+ parser.add_argument('--limit', type=int, default=None)
191
+ parser.add_argument('--seed', type=int, default=1234)
192
+ parser.add_argument('--bias-correction', action='store_true')
193
+ parser.add_argument('-f', '--force-overwrite', action='store_true')
194
+ parser.add_argument('--debug', action='store_true')
195
+ args = parser.parse_args()
196
+
197
+ if args.debug:
198
+ level = logging.DEBUG
199
+ else:
200
+ level = logging.INFO
201
+ logging.basicConfig(level=level)
202
+
203
+ main(args)
autoprompt/label_search.py ADDED
@@ -0,0 +1,162 @@
1
+ """
2
+ This is a hacky little attempt using the tools from the trigger creation script to identify a
3
+ good set of label strings. The idea is to train a linear classifier over the predict token and
4
+ then look at the most similar tokens.
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import DataLoader
14
+ from transformers import (
15
+ AutoConfig, AutoModelWithLMHead, AutoTokenizer, BertForMaskedLM, RobertaForMaskedLM
16
+ )
17
+ from tqdm import tqdm
18
+
19
+ import autoprompt.utils as utils
20
+ import autoprompt.create_trigger as ct
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ def load_pretrained(model_name):
27
+ """
28
+ Loads pretrained HuggingFace config/model/tokenizer, as well as performs required
29
+ initialization steps to facilitate working with triggers.
30
+ """
31
+ config = AutoConfig.from_pretrained(model_name)
32
+ model = AutoModelWithLMHead.from_pretrained(model_name, config=config)
33
+ model.eval()
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+ utils.add_task_specific_tokens(tokenizer)
36
+ return config, model, tokenizer
37
+
38
+
39
+ def get_final_embeddings(model):
40
+ if isinstance(model, BertForMaskedLM):
41
+ return model.cls.predictions.transform
42
+ elif isinstance(model, RobertaForMaskedLM):
43
+ return model.lm_head.layer_norm
44
+ else:
45
+ raise NotImplementedError(f'{model} not currently supported')
46
+
47
+
48
+ def get_word_embeddings(model):
49
+ if isinstance(model, BertForMaskedLM):
50
+ return model.cls.predictions.decoder.weight
51
+ elif isinstance(model, RobertaForMaskedLM):
52
+ return model.lm_head.decoder.weight
53
+ else:
54
+ raise NotImplementedError(f'{model} not currently supported')
55
+
56
+
57
+ def main(args):
58
+ ct.set_seed(args.seed)
59
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
+
61
+ logger.info('Loading model, tokenizer, etc.')
62
+ config, model, tokenizer = load_pretrained(args.model_name)
63
+ model.to(device)
64
+ final_embeddings = get_final_embeddings(model)
65
+ embedding_storage = utils.OutputStorage(final_embeddings)
66
+ word_embeddings = get_word_embeddings(model)
67
+
68
+ label_map = json.loads(args.label_map)
69
+ reverse_label_map = {y: x for x, y in label_map.items()}
70
+ templatizer = utils.TriggerTemplatizer(
71
+ args.template,
+ config,
72
+ tokenizer,
73
+ label_map=label_map,
74
+ label_field=args.label_field,
75
+ add_special_tokens=False
76
+ )
77
+
78
+ # The weights of this projection will help identify the best label words.
79
+ projection = torch.nn.Linear(config.hidden_size, len(label_map))
80
+ projection.to(device)
81
+
82
+ # Obtain the initial trigger tokens and label mapping
83
+ if args.initial_trigger:
84
+ trigger_ids = tokenizer.encode(
85
+ args.initial_trigger,
86
+ add_special_tokens=False,
87
+ add_prefix_space=True
88
+ )
89
+ assert len(trigger_ids) == templatizer.num_trigger_tokens
90
+ else:
91
+ trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
92
+ trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)
93
+
94
+ logger.info('Loading datasets')
95
+ collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
96
+ train_dataset = utils.load_trigger_dataset(args.train, templatizer, use_ctx=False)
97
+ train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
98
+
99
+ optimizer = torch.optim.Adam(projection.parameters(), lr=args.lr)
100
+
101
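+ # Score every vocabulary token against each class: the classifier weights live in the
+ # same space as the output word embeddings, so their inner products rank candidate
+ # label tokens per class.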
+ scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
102
+ scores = F.softmax(scores, dim=0)
103
+ for i, row in enumerate(scores):
104
+ _, top = row.topk(args.k)
105
+ decoded = tokenizer.convert_ids_to_tokens(top)
106
+ logger.info(f"Top k for class {reverse_label_map[i]}: {', '.join(decoded)}")
107
+
108
+ logger.info('Training')
109
+ for i in range(args.iters):
110
+ pbar = tqdm(train_loader)
111
+ for model_inputs, labels in pbar:
112
+ optimizer.zero_grad()
113
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
114
+ labels = labels.to(device)
115
+ trigger_mask = model_inputs.pop('trigger_mask')
116
+ predict_mask = model_inputs.pop('predict_mask')
117
+ model_inputs = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
118
+ with torch.no_grad():
119
+ model(**model_inputs)
120
+ embeddings = embedding_storage.get()
121
+ predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
122
+ logits = projection(predict_embeddings)
123
+ loss = F.cross_entropy(logits, labels.squeeze(-1))
124
+ loss.backward()
125
+ optimizer.step()
126
+ pbar.set_description(f'loss: {loss.item(): 0.4f}')
127
+
128
+ scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
129
+ scores = F.softmax(scores, dim=0)
130
+ for i, row in enumerate(scores):
131
+ _, top = row.topk(args.k)
132
+ decoded = tokenizer.convert_ids_to_tokens(top)
133
+ logger.info(f"Top k for class {reverse_label_map[i]}: {', '.join(decoded)}")
134
+
135
+
136
+
137
+ if __name__ == '__main__':
138
+ parser = argparse.ArgumentParser()
139
+ parser.add_argument('--train', type=Path, required=True, help='Train data path')
140
+ parser.add_argument('--template', type=str, help='Template string')
141
+ parser.add_argument('--label-map', type=str, help='JSON object defining label map')
142
+ parser.add_argument('--initial-trigger', type=str, default=None, help='Manual prompt')
143
+ parser.add_argument('--label-field', type=str, default='label',
144
+ help='Name of the label field')
145
+ parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
146
+ parser.add_argument('--k', type=int, default=50, help='Number of label tokens to print')
147
+ parser.add_argument('--bsz', type=int, default=32, help='Batch size')
148
+ parser.add_argument('--iters', type=int, default=10,
149
+ help='Number of iterations to run label search')
150
+ parser.add_argument('--model-name', type=str, default='bert-base-cased',
151
+ help='Model name passed to HuggingFace AutoX classes.')
152
+ parser.add_argument('--seed', type=int, default=0)
153
+ parser.add_argument('--debug', action='store_true')
154
+ args = parser.parse_args()
155
+
156
+ if args.debug:
157
+ level = logging.DEBUG
158
+ else:
159
+ level = logging.INFO
160
+ logging.basicConfig(level=level)
161
+
162
+ main(args)
autoprompt/popsicle.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ Frozen model with a linear topping: a trainable linear classifier on top of a frozen transformer.
3
+ """
4
+ import logging
5
+
6
+ import torch
7
+ from torch.nn import CrossEntropyLoss, MSELoss
8
+ from transformers import (
9
+ AutoConfig,
10
+ BertConfig,
11
+ BertForSequenceClassification,
12
+ PretrainedConfig,
13
+ RobertaConfig,
14
+ RobertaForSequenceClassification
15
+ )
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+
22
+ class Bertsicle(BertForSequenceClassification):
23
+ def forward(
24
+ self,
25
+ input_ids=None,
26
+ attention_mask=None,
27
+ token_type_ids=None,
28
+ position_ids=None,
29
+ head_mask=None,
30
+ inputs_embeds=None,
31
+ labels=None,
32
+ ):
33
+ with torch.no_grad():
34
+ outputs = self.bert(
35
+ input_ids,
36
+ attention_mask=attention_mask,
37
+ token_type_ids=token_type_ids,
38
+ position_ids=position_ids,
39
+ head_mask=head_mask,
40
+ inputs_embeds=inputs_embeds,
41
+ )
42
+
43
+ # Mean-pool over the sequence output (outputs[0]) instead of taking the pooler output.
44
+ pooled_output = outputs[0]
45
+ pooled_output = pooled_output[:, 1:, :]  # eliminate the [CLS] token
46
+ pooled_output = torch.mean(pooled_output, dim=1)
47
+
48
+ pooled_output = self.dropout(pooled_output)
49
+ logits = self.classifier(pooled_output)
50
+
51
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
52
+
53
+ if labels is not None:
54
+ if self.num_labels == 1:
55
+ # We are doing regression
56
+ loss_fct = MSELoss()
57
+ loss = loss_fct(logits.view(-1), labels.view(-1))
58
+ else:
59
+ loss_fct = CrossEntropyLoss()
60
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
61
+ outputs = (loss,) + outputs
62
+
63
+ return outputs # (loss), logits, (hidden_states), (attentions)
64
+
65
+
66
+ class Robertasicle(RobertaForSequenceClassification):
67
+ def forward(
68
+ self,
69
+ input_ids=None,
70
+ attention_mask=None,
71
+ token_type_ids=None,
72
+ position_ids=None,
73
+ head_mask=None,
74
+ inputs_embeds=None,
75
+ labels=None,
76
+ ):
77
+ with torch.no_grad():
78
+ outputs = self.roberta(
79
+ input_ids,
80
+ attention_mask=attention_mask,
81
+ token_type_ids=token_type_ids,
82
+ position_ids=position_ids,
83
+ head_mask=head_mask,
84
+ inputs_embeds=inputs_embeds,
85
+ )
86
+ sequence_output = outputs[0]
87
+ sequence_output = sequence_output[:, 1:, :] # eliminating <s> token
88
+ pooled_sequence_output = torch.mean(sequence_output, dim=1, keepdim=True)
89
+ logits = self.classifier(pooled_sequence_output)
90
+ outputs = (logits,) + outputs[2:]
91
+ if labels is not None:
92
+ if self.num_labels == 1:
93
+ # We are doing regression
94
+ loss_fct = MSELoss()
95
+ loss = loss_fct(logits.view(-1), labels.view(-1))
96
+ else:
97
+ loss_fct = CrossEntropyLoss()
98
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
99
+ outputs = (loss,) + outputs
100
+
101
+ return outputs # (loss), logits, (hidden_states), (attentions)
102
+
103
+
104
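+ # Dispatch table mapping HuggingFace config classes to their frozen-backbone variants.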
+ MODEL_MAPPING = {
105
+ RobertaConfig: Robertasicle,
106
+ BertConfig: Bertsicle
107
+ }
108
+
109
+
110
+ class AutoPopsicle:
111
+ def __init__(self):
112
+ raise EnvironmentError('AutoPopsicle cannot be instantiated directly. Use `.from_pretrained()` or `.from_config()`.')
113
+
114
+ @classmethod
115
+ def from_config(cls, config):
116
+ for config_class, model_class in MODEL_MAPPING.items():
117
+ if isinstance(config, config_class):
118
+ return model_class(config)
119
+
120
+ raise ValueError('We do not support this config.')
121
+
122
+ @classmethod
123
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
124
+ config = kwargs.pop("config", None)
125
+ if not isinstance(config, PretrainedConfig):
126
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
127
+
128
+ for config_class, model_class in MODEL_MAPPING.items():
129
+ if isinstance(config, config_class):
130
+ logger.info(f'Config class: {config_class}')
131
+ logger.info(f'Model class: {model_class}')
132
+ return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
133
+
134
+ raise ValueError(f'We do not support "{pretrained_model_name_or_path}".')
autoprompt/run_linear_probe.py ADDED
@@ -0,0 +1,151 @@
1
+ """
2
+ Script for running a linear probe on glue tasks.
3
+
4
+ Largely copied from:
5
+ https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py
6
+ """
7
+ import argparse
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import DataLoader
14
+ from transformers import AutoConfig, AutoTokenizer, WEIGHTS_NAME
15
+ from tqdm import tqdm
16
+
17
+ from autoprompt.popsicle import AutoPopsicle
18
+ import autoprompt.utils as utils
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def main(args):
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+
26
+ config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
27
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
28
+ model = AutoPopsicle.from_pretrained(args.model_name, config=config)
29
+ model.to(device)
30
+
31
+ collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
32
+ train_dataset, label_map = utils.load_classification_dataset(
33
+ args.train,
34
+ tokenizer,
35
+ args.field_a,
36
+ args.field_b,
37
+ args.label_field
38
+ )
39
+ train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
40
+ dev_dataset, _ = utils.load_classification_dataset(
41
+ args.dev,
42
+ tokenizer,
43
+ args.field_a,
44
+ args.field_b,
45
+ args.label_field,
46
+ label_map
47
+ )
48
+ dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
49
+ test_dataset, _ = utils.load_classification_dataset(
50
+ args.test,
51
+ tokenizer,
52
+ args.field_a,
53
+ args.field_b,
54
+ args.label_field,
55
+ label_map
56
+ )
57
+ test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
58
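+ # Only the classifier head is optimized; the transformer backbone stays frozen.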
+ optimizer = torch.optim.Adam(model.classifier.parameters(), lr=args.lr, weight_decay=1e-6)
59
+
60
+ if not args.ckpt_dir.exists():
61
+ logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
62
+ args.ckpt_dir.mkdir(parents=True)
63
+ elif not args.force_overwrite:
64
+ raise RuntimeError('Checkpoint directory already exists.')
65
+
66
+ best_accuracy = 0
67
+ try:
68
+ for epoch in range(args.epochs):
69
+ logger.info('Training...')
70
+ model.eval() # Only the linear head is trained - keep the frozen backbone deterministic (no dropout).
71
+ avg_loss = utils.ExponentialMovingAverage()
72
+ pbar = tqdm(train_loader)
73
+ for model_inputs, labels in pbar:
74
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
75
+ labels = labels.to(device)
76
+ optimizer.zero_grad()
77
+ logits, *_ = model(**model_inputs)
78
+ loss = F.cross_entropy(logits, labels.squeeze(-1))
79
+ loss.backward()
80
+ optimizer.step()
81
+ avg_loss.update(loss.item())
82
+ pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')
83
+
84
+ logger.info('Evaluating...')
85
+ model.eval()
86
+ correct = 0
87
+ total = 0
88
+ for model_inputs, labels in dev_loader:
89
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
90
+ labels = labels.to(device)
91
+ logits, *_ = model(**model_inputs)
92
+ _, preds = logits.max(dim=-1)
93
+ correct += (preds == labels.squeeze(-1)).sum().item()
94
+ total += labels.size(0)
95
+ accuracy = correct / (total + 1e-13)
96
+ logger.info(f'Accuracy: {accuracy : 0.4f}')
97
+
98
+ if accuracy > best_accuracy:
99
+ logger.info('Best performance so far. Saving...')
102
+ model.save_pretrained(args.ckpt_dir)
103
+ tokenizer.save_pretrained(args.ckpt_dir)
104
+ best_accuracy = accuracy
105
+
106
+ except KeyboardInterrupt:
107
+ logger.info('Training manually terminated.')
108
+
109
+ logger.info('Testing...')
110
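+ # `save_pretrained` wrote the best weights to WEIGHTS_NAME inside ckpt_dir, so
+ # reload them before the final test pass.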
+ checkpoint = torch.load(args.ckpt_dir / WEIGHTS_NAME)
111
+ model.load_state_dict(checkpoint)
112
+ model.eval()
113
+ correct = 0
114
+ total = 0
115
+ for model_inputs, labels in test_loader:
116
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
117
+ labels = labels.to(device)
118
+ logits, *_ = model(**model_inputs)
119
+ _, preds = logits.max(dim=-1)
120
+ correct += (preds == labels.squeeze(-1)).sum().item()
121
+ total += labels.size(0)
122
+ accuracy = correct / (total + 1e-13)
123
+ logger.info(f'Accuracy: {accuracy : 0.4f}')
124
+
125
+
126
+ if __name__ == '__main__':
127
+ parser = argparse.ArgumentParser()
128
+ parser.add_argument('--model-name', type=str)
129
+ parser.add_argument('--train', type=Path)
130
+ parser.add_argument('--dev', type=Path)
131
+ parser.add_argument('--test', type=Path)
132
+ parser.add_argument('--field-a', type=str)
133
+ parser.add_argument('--field-b', type=str, default=None)
134
+ parser.add_argument('--label-field', type=str, default='label')
135
+ parser.add_argument('--ckpt-dir', type=Path, default=Path('ckpt/'))
136
+ parser.add_argument('--num-labels', type=int, default=2)
137
+ parser.add_argument('--bsz', type=int, default=32)
138
+ parser.add_argument('--epochs', type=int, default=10)
139
+ parser.add_argument('--lr', type=float, default=1e-3)
140
+ parser.add_argument('-f', '--force-overwrite', action='store_true')
141
+ parser.add_argument('--debug', action='store_true')
142
+ parser.add_argument('--log_file', type=str, default='log.txt')
143
+ args = parser.parse_args()
144
+
145
+ if args.debug:
146
+ level = logging.DEBUG
147
+ else:
148
+ level = logging.INFO
149
+ logging.basicConfig(level=level, filename=args.log_file)
150
+
151
+ main(args)
autoprompt/utils.py ADDED
@@ -0,0 +1,376 @@
1
+ import csv
2
+ import copy
3
+ import json
4
+ import logging
5
+ import random
6
+ from collections import defaultdict
7
+
8
+ import torch
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+
12
+ MAX_CONTEXT_LEN = 50
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def pad_squeeze_sequence(sequence, *args, **kwargs):
19
+ """Squeezes fake batch dimension added by tokenizer before padding sequence."""
20
+ return pad_sequence([x.squeeze(0) for x in sequence], *args, **kwargs)
21
+
22
+
23
+ class OutputStorage:
24
+ """
25
+ This object stores the output of a given PyTorch module (captured with a forward
26
+ hook), which otherwise might not be retained.
27
+ """
28
+ def __init__(self, module):
29
+ self._stored_output = None
30
+ module.register_forward_hook(self.hook)
31
+
32
+ def hook(self, module, input, output):
33
+ self._stored_output = output
34
+
35
+ def get(self):
36
+ return self._stored_output
37
+
38
+
39
+ class ExponentialMovingAverage:
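+ # NOTE: Despite the name, this tracks a plain running average; `weight` is
+ # currently stored but unused in the update.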
40
+ def __init__(self, weight=0.3):
41
+ self._weight = weight
42
+ self.reset()
43
+
44
+ def update(self, x):
45
+ self._x += x
46
+ self._i += 1
47
+
48
+ def reset(self):
49
+ self._x = 0
50
+ self._i = 0
51
+
52
+ def get_metric(self):
53
+ return self._x / (self._i + 1e-13)
54
+
55
+
56
+ class Collator:
57
+ """
58
+ Collates model inputs and labels into padded batches.
59
+ """
60
+ def __init__(self, pad_token_id=0):
61
+ self._pad_token_id = pad_token_id
62
+
63
+ def __call__(self, features):
64
+ # Separate the list of inputs and labels
65
+ model_inputs, labels = list(zip(*features))
66
+ # Assume that all inputs have the same keys as the first
67
+ proto_input = model_inputs[0]
68
+ keys = list(proto_input.keys())
69
+ padded_inputs = {}
70
+ for key in keys:
71
+ if key == 'input_ids':
72
+ padding_value = self._pad_token_id
73
+ else:
74
+ padding_value = 0
75
+ # NOTE: We need to squeeze to get rid of fake batch dim.
76
+ sequence = [x[key] for x in model_inputs]
77
+ padded = pad_squeeze_sequence(sequence, batch_first=True, padding_value=padding_value)
78
+ padded_inputs[key] = padded
79
+ labels = pad_squeeze_sequence(labels, batch_first=True, padding_value=0)
80
+ return padded_inputs, labels
81
+
82
+
83
+ def encode_label(tokenizer, label, tokenize=False):
84
+ """
85
+ Helper function for encoding labels. Deals with the subtleties of handling multiple tokens.
86
+ """
87
+ if isinstance(label, str):
88
+ if tokenize:
89
+ # Ensure label is properly tokenized, and only retain first token
90
+ # if it gets split into multiple tokens. TODO: Make sure this is
91
+ # desired behavior.
92
+ tokens = tokenizer.tokenize(label)
93
+ if len(tokens) > 1:
94
+ raise ValueError(f'Label "{label}" gets mapped to multiple tokens.')
95
+ if tokens[0] == tokenizer.unk_token:
96
+ raise ValueError(f'Label "{label}" gets mapped to unk.')
97
+ label = tokens[0]
98
+ encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0)
99
+ elif isinstance(label, list):
100
+ encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0)
101
+ elif isinstance(label, int):
102
+ encoded = torch.tensor([[label]])
103
+ return encoded
104
+
105
+
106
+ class TriggerTemplatizer:
107
+ """
108
+ An object to facilitate creating transformers-friendly trigger inputs from a template.
109
+
110
+ Parameters
111
+ ==========
112
+ template : str
113
+ The template string, comprised of the following tokens:
114
+ [T] to mark a trigger placeholder.
115
+ [P] to mark a prediction placeholder.
116
+ {fields} arbitrary fields instantiated from the dataset instances.
117
+ For example, an NLI template might look like:
118
+ "[T] [T] [T] {premise} [P] {hypothesis}"
119
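+ config : transformers.PretrainedConfig
+ The model config (used to special-case BERT's token_type_ids handling).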
+ tokenizer : PretrainedTokenizer
120
+ A HuggingFace tokenizer. Must have special trigger and predict tokens.
121
+ add_special_tokens : bool
122
+ Whether or not to add special tokens when encoding. Default: False.
123
+ """
124
+ def __init__(self,
125
+ template,
126
+ config,
127
+ tokenizer,
128
+ label_field='label',
129
+ label_map=None,
130
+ tokenize_labels=False,
131
+ add_special_tokens=False,
132
+ use_ctx=False):
133
+ if not hasattr(tokenizer, 'predict_token') or \
134
+ not hasattr(tokenizer, 'trigger_token'):
135
+ raise ValueError(
136
+ 'Tokenizer missing special trigger and predict tokens in vocab. '
137
+ 'Use `utils.add_task_specific_tokens` to add them.'
138
+ )
139
+ self._template = template
140
+ self._config = config
141
+ self._tokenizer = tokenizer
142
+ self._label_field = label_field
143
+ self._label_map = label_map
144
+ self._tokenize_labels = tokenize_labels
145
+ self._add_special_tokens = add_special_tokens
146
+ self._use_ctx = use_ctx
147
+
148
+ @property
149
+ def num_trigger_tokens(self):
150
+ return sum(token == '[T]' for token in self._template.split())
151
+
152
+ def __call__(self, format_kwargs):
153
+ # Format the template string
154
+ format_kwargs = format_kwargs.copy()
155
+ label = format_kwargs.pop(self._label_field)
156
+ text = self._template.format(**format_kwargs)
157
+ if label is None:
158
+ raise Exception(f'Bad data: {text}')
159
+
160
+ # Have the tokenizer encode the text and process the output to:
161
+ # - Create a trigger and predict mask
162
+ # - Replace the predict token with a mask token
163
+ model_inputs = self._tokenizer.encode_plus(
164
+ text,
165
+ add_special_tokens=self._add_special_tokens,
166
+ return_tensors='pt'
167
+ )
168
+ input_ids = model_inputs['input_ids']
169
+ trigger_mask = input_ids.eq(self._tokenizer.trigger_token_id)
170
+ predict_mask = input_ids.eq(self._tokenizer.predict_token_id)
171
+ input_ids[predict_mask] = self._tokenizer.mask_token_id
172
+
173
+ model_inputs['trigger_mask'] = trigger_mask
174
+ model_inputs['predict_mask'] = predict_mask
175
+
176
+ # For relation extraction with BERT, update token_type_ids to reflect the two different sequences
177
+ if self._use_ctx and self._config.model_type == 'bert':
178
+ sep_token_indices = (input_ids.squeeze(0) == self._tokenizer.convert_tokens_to_ids(self._tokenizer.sep_token)).nonzero().flatten()
179
+ sequence_b_indices = torch.arange(sep_token_indices[0], sep_token_indices[1] + 1).long().unsqueeze(0)
180
+ model_inputs['token_type_ids'].scatter_(1, sequence_b_indices, 1)
181
+
182
+ # Encode the label(s)
183
+ if self._label_map is not None:
184
+ label = self._label_map[label]
185
+ label_id = encode_label(
186
+ tokenizer=self._tokenizer,
187
+ label=label,
188
+ tokenize=self._tokenize_labels
189
+ )
190
+
191
+ return model_inputs, label_id
192
+
193
+
194
+ def add_task_specific_tokens(tokenizer):
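+ """Adds the special trigger [T], predict [P], and LAMA [Y] tokens to the tokenizer."""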
195
+ tokenizer.add_special_tokens({
196
+ 'additional_special_tokens': ['[T]', '[P]', '[Y]']
197
+ })
198
+ tokenizer.trigger_token = '[T]'
199
+ tokenizer.trigger_token_id = tokenizer.convert_tokens_to_ids('[T]')
200
+ tokenizer.predict_token = '[P]'
201
+ tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
202
+ # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
203
+ # tokenizer.lama_x = '[X]'
204
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
205
+ tokenizer.lama_y = '[Y]'
206
+ tokenizer.lama_y_id = tokenizer.convert_tokens_to_ids('[Y]')
207
+
208
+
209
+
210
+ def load_tsv(fname):
211
+ with open(fname, 'r') as f:
212
+ reader = csv.DictReader(f, delimiter='\t')
213
+ for row in reader:
214
+ yield row
215
+
216
+
217
+ def load_jsonl(fname):
218
+ with open(fname, 'r') as f:
219
+ for line in f:
220
+ yield json.loads(line)
221
+
222
+
223
+ LOADERS = {
224
+ '.tsv': load_tsv,
225
+ '.jsonl': load_jsonl
226
+ }
227
+
228
+
229
+ def load_trigger_dataset(fname, templatizer, use_ctx, limit=None):
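+ """Loads and templatizes a trigger dataset; with `use_ctx`, a context sentence is attached (relation extraction)."""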
230
+ loader = LOADERS[fname.suffix]
231
+ instances = []
232
+
233
+ for x in loader(fname):
234
+ try:
235
+ if use_ctx:
236
+ # For relation extraction, skip facts that don't have context sentence
237
+ if 'evidences' not in x:
238
+ logger.warning('Skipping RE sample because it lacks context sentences: {}'.format(x))
239
+ continue
240
+
241
+ evidences = x['evidences']
242
+
243
+ # Randomly pick a context sentence
244
+ obj_surface, masked_sent = random.choice([(evidence['obj_surface'], evidence['masked_sentence']) for evidence in evidences])
245
+ words = masked_sent.split()
246
+ if len(words) > MAX_CONTEXT_LEN:
247
+ # If the masked sentence is too long, keep only the first MAX_CONTEXT_LEN tokens so we retain as many training samples as possible.
248
+ masked_sent = ' '.join(words[:MAX_CONTEXT_LEN])
249
+
250
+ # Replace [MASK] in the (possibly truncated) context sentence with the object surface form.
251
+ # We match the literal string '[MASK]' because all TREx context sentences use it.
252
+ context = masked_sent.replace('[MASK]', obj_surface)
253
+ x['context'] = context
254
+ model_inputs, label_id = templatizer(x)
255
+ else:
256
+ model_inputs, label_id = templatizer(x)
257
+ except ValueError as e:
258
+ logger.warning('Encountered error "%s" when processing "%s". Skipping.', e, x)
259
+ continue
260
+ else:
261
+ instances.append((model_inputs, label_id))
262
+ if limit:
263
+ return random.sample(instances, limit)
264
+ else:
265
+ return instances
266
+
267
+
268
+ def load_augmented_trigger_dataset(fname, templatizer, limit=None):
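+ """Builds a perturbed relation extraction dataset by swapping each fact's object for a random other object."""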
269
+ loader = LOADERS[fname.suffix]
270
+ instances = []
271
+
272
+ # For augmented relation extraction, we need to replace obj_label with another obj_label, and replace obj_surface with a surface form of the new obj_label
273
+ unique_objs_dict = defaultdict(list)
274
+ # Also for augmented relation extraction, we need to accumulate all facts and process them afterwards
275
+ facts = []
276
+
277
+ for x in loader(fname):
278
+ try:
279
+ sub_label = x['sub_label']
280
+ obj_label = x['obj_label']
281
+
282
+ # For relation extraction, skip facts that don't have context sentence
283
+ if 'evidences' not in x:
284
+ logger.warning('Skipping RE sample because it lacks context sentences: {}'.format(x))
285
+ continue
286
+
287
+ evidences = x['evidences']
288
+
289
+ # Gather all unique objects and their surface forms for augmented relation extraction
290
+ for evidence in evidences:
291
+ obj_surface = evidence['obj_surface']
292
+ masked_sent = evidence['masked_sentence']
293
+ unique_objs_dict[obj_label].append(obj_surface)
294
+
295
+ # Randomly pick a context sentence
296
+ obj_surface, masked_sent = random.choice([(evidence['obj_surface'], evidence['masked_sentence']) for evidence in evidences])
297
+ words = masked_sent.split()
298
+ if len(words) > MAX_CONTEXT_LEN:
299
+ # If the masked sentence is too long, use the first X tokens. For training we want to keep as many samples as we can.
300
+ masked_sent = ' '.join(words[:MAX_CONTEXT_LEN])
301
+
302
+ x['context'] = masked_sent
303
+ facts.append(x)
304
+ except ValueError as e:
305
+ logger.warning('Encountered error "%s" when processing "%s". Skipping.', e, x)
306
+
307
+ # Go through all facts and replace each object with a new one. Also insert the new object (surface form) into the masked sentence
308
+ synth_facts = []
309
+ for fact in facts:
310
+ sub_label = fact['sub_label']
311
+ obj_label = fact['obj_label']
312
+ masked_sent = fact['context']
314
+ synth_obj_label = random.choice([x for x in unique_objs_dict.keys() if x != obj_label])
315
+ synth_obj_surface = random.choice(unique_objs_dict[synth_obj_label])
316
+ synth_ctx = masked_sent.replace('[MASK]', synth_obj_surface)
318
+ # Reassign the labels and context sentence
319
+ synth_fact = copy.deepcopy(fact)
320
+ synth_fact['sub_label'] = sub_label
321
+ synth_fact['obj_label'] = synth_obj_label
322
+ synth_fact['context'] = synth_ctx
323
+ synth_facts.append(synth_fact)
324
+
325
+ # Go through facts, templatize each one, then append them to instances
326
+ for fact in synth_facts:
327
+ model_inputs, label_id = templatizer(fact)
328
+ instances.append((model_inputs, label_id))
329
+
330
+ if limit:
331
+ return random.sample(instances, limit)
332
+ else:
333
+ return instances
334
+
335
+
336
+ def load_classification_dataset(
337
+ fname,
338
+ tokenizer,
339
+ input_field_a,
340
+ input_field_b=None,
341
+ label_field='label',
342
+ label_map=None,
343
+ limit=None
344
+ ):
345
+ """
346
+ Loads a dataset for classification.
+
+ Parameters
+ ==========
+ fname : pathlib.Path
+ Path to a .tsv or .jsonl dataset file.
+ tokenizer : transformers.PretrainedTokenizer
+ Maps text to id tensors.
+ input_field_a : str
+ Name of the first input text field.
+ input_field_b : str, optional
+ Name of the second input text field, for sentence-pair tasks.
+ label_field : str
+ Name of the label field. Default: 'label'.
+ label_map : dict, optional
+ Existing label-to-id mapping; extended in place as new labels are seen.
+ limit : int, optional
+ If given, randomly subsample this many instances.
353
+ """
354
+ instances = []
355
+ label_map = label_map or {}
356
+ loader = LOADERS[fname.suffix]
357
+ for instance in loader(fname):
358
+ logger.debug(instance)
359
+ model_inputs = tokenizer.encode_plus(
360
+ instance[input_field_a],
361
+ instance[input_field_b] if input_field_b else None,
362
+ add_special_tokens=True,
363
+ # add_prefix_space=True,
364
+ return_tensors='pt'
365
+ )
366
+ logger.debug(model_inputs)
367
+ label = instance[label_field]
368
+ if label not in label_map:
369
+ label_map[label] = len(label_map)
370
+ label_id = label_map[label]
371
+ label_id = torch.tensor([[label_id]]) # To match the shape the collator expects
372
+ logger.debug(f'Label id: {label_id}')
373
+ instances.append((model_inputs, label_id))
374
+ if limit:
375
+ instances = random.sample(instances, limit)
376
+ return instances, label_map
prompts/fact_retrieval_bert_prompts.jsonl ADDED
@@ -0,0 +1,41 @@
1
+ {"relation": "P1001", "template": "[X]vik nationwide disabilities policing within [Y]."}
2
+ {"relation": "P101", "template": "[X] probability earliest fame totaled studying [Y]."}
3
+ {"relation": "P103", "template": "[X]PA communerug speaks proper [Y]."}
4
+ {"relation": "P106", "template": "[X] supporters studied politicians musician turned [Y]."}
5
+ {"relation": "P108", "template": "[X] 1987adeNBC computing succeeded [Y]."}
6
+ {"relation": "P127", "template": "[X] is hindwings mainline architecture within [Y]."}
7
+ {"relation": "P1303", "template": "[X] playingdrum concertoative electric [Y]."}
8
+ {"relation": "P131", "template": "[X]ediatric close suburb throughout northwest [Y]."}
9
+ {"relation": "P136", "template": "[X] freaking genre orchestra fiction acid [Y]."}
10
+ {"relation": "P1376", "template": "[X] boasts native territory traditionally called [Y]."}
11
+ {"relation": "P138", "template": "[X] consistslanche classical name of [Y]."}
12
+ {"relation": "P140", "template": "[X]urn openly explicitly mosques practicing [Y]."}
13
+ {"relation": "P1412", "template": "[X] receivedorganisation 1904 speaking only [Y]."}
14
+ {"relation": "P159", "template": "[X] isnky galleries headquartered in [Y]."}
15
+ {"relation": "P17", "template": "[X] is association footballled southeastern [Y]."}
16
+ {"relation": "P176", "template": "[X] was flight series manufactured by [Y]."}
17
+ {"relation": "P178", "template": "[X] is memory arcade branding by [Y]."}
18
+ {"relation": "P19", "template": "[X] clocks literary economist relocated to [Y]."}
19
+ {"relation": "P190", "template": "[X] proceeded worldwidedick offices near [Y]."}
20
+ {"relation": "P20", "template": "[X] reorganizationotype photographic studio in [Y]."}
21
+ {"relation": "P264", "template": "[X] cameo explanation\u00f6table sued [Y]."}
22
+ {"relation": "P27", "template": "[X] m\u00b3 badminton pieces internationally representing [Y]."}
23
+ {"relation": "P276", "template": "[X] consists kilograms centred neighborhoods in [Y]."}
24
+ {"relation": "P279", "template": "[X] is \u00ee adequately termed coated [Y]."}
25
+ {"relation": "P30", "template": "[X] is commune polar continent in [Y]."}
26
+ {"relation": "P31", "template": "[X] isious 1970s southwardlier [Y]."}
27
+ {"relation": "P36", "template": "[X] includesiidae geologic countryside near [Y]."}
28
+ {"relation": "P361", "template": "[X] isaul archaic section of [Y]."}
29
+ {"relation": "P364", "template": "[X]dak \u20ac dancers speak standard [Y]."}
30
+ {"relation": "P37", "template": "[X]inen dialects resembled officially exclusively [Y]."}
31
+ {"relation": "P39", "template": "[X] explorers voting municipal \u2192 consecrated [Y]."}
32
+ {"relation": "P407", "template": "[X] playedi\u0107 every dialect but [Y]."}
33
+ {"relation": "P413", "template": "[X] played colors skier \u2194 defensive [Y]."}
34
+ {"relation": "P449", "template": "[X] uncredited recording remake aired on [Y]."}
35
+ {"relation": "P463", "template": "[X] splits artisticlogy prior joining [Y]."}
36
+ {"relation": "P47", "template": "[X] shares undrafted border northeast neighbours [Y]."}
37
+ {"relation": "P495", "template": "[X] album spanninggie chart in [Y]."}
38
+ {"relation": "P527", "template": "[X] nickname involves \u032f\u00bddized [Y]."}
39
+ {"relation": "P530", "template": "[X] nightclubrah preceding relations with [Y]."}
40
+ {"relation": "P740", "template": "[X] refers drum blog centred downtown [Y]."}
41
+ {"relation": "P937", "template": "[X] vol \u300elson gallery in [Y]."}
prompts/fact_retrieval_roberta_prompts.jsonl ADDED
@@ -0,0 +1,41 @@
1
+ {"relation": "P1001", "template": " [X]\u00a2List unsu rivers spanning [Y] ."}
2
+ {"relation": "P101", "template": " [X] 1830 dissertation applying mathsucci [Y] ."}
3
+ {"relation": "P103", "template": " [X]neau optionally fluent!?\" traditional [Y] ."}
4
+ {"relation": "P106", "template": " [X] (), astronomers businessman\u00b7former [Y] ."}
5
+ {"relation": "P108", "template": " [X] heads opio computer divisionersen [Y] ."}
6
+ {"relation": "P127", "template": " [X] picThom unwillingness officially governs [Y] ."}
7
+ {"relation": "P1303", "template": " [X]Trump learned soloKeefe classical [Y] ."}
8
+ {"relation": "P131", "template": " [X] scenic neighbourhood occurred enqu northeastern [Y] ."}
9
+ {"relation": "P136", "template": " [X] blends postwar hostage drama sax [Y] ."}
10
+ {"relation": "P1376", "template": " [X] limestone depositedati boroughDepending [Y] ."}
11
+ {"relation": "P138", "template": " [X] =alysis northern spellingSaint [Y] ."}
12
+ {"relation": "P140", "template": " [X] traced pagan fascism individuality extremist [Y] ."}
13
+ {"relation": "P1412", "template": " [X] translatedANCauld writings binaries [Y] ."}
14
+ {"relation": "P159", "template": " [X] spinsCompany organisedLocation near [Y] ."}
15
+ {"relation": "P17", "template": " [X]exec scenic provinces iodine northeastern [Y] ."}
16
+ {"relation": "P176", "template": " [X] 125definition enormously stunned manufacturer [Y] ."}
17
+ {"relation": "P178", "template": " [X] 1987 floppy simulator users sued [Y] ."}
18
+ {"relation": "P19", "template": " [X] 2002 protesting disco constructionamine [Y] ."}
19
+ {"relation": "P190", "template": " [X] flight facultiesyna arrivesfolios [Y] ."}
20
+ {"relation": "P20", "template": " [X].. enigmatic twentieth nowadays near [Y] ."}
21
+ {"relation": "P264", "template": " [X] touring 1958 defunct videog label [Y] ."}
22
+ {"relation": "P27", "template": " [X] offic organise forests statutes northwestern [Y] ."}
23
+ {"relation": "P276", "template": " [X] manoeuv constructs whistleblowers hills near [Y] ."}
24
+ {"relation": "P279", "template": " [X],formerly prayers unstaceous [Y] ."}
25
+ {"relation": "P30", "template": " [X] coral caves symb polar zone [Y] ."}
26
+ {"relation": "P31", "template": " [X] (), therapists nationallyrecorded enchanted [Y] ."}
27
+ {"relation": "P36", "template": " [X] 1954 misinterpretburg narrowly battered [Y] ."}
28
+ {"relation": "P361", "template": " [X], supplementaryfoot structuresNorthern [Y] ."}
29
+ {"relation": "P364", "template": " [X]vanathering preferred languagesEnglish [Y] ."}
30
+ {"relation": "P37", "template": " [X]onen tribes descending speak mainly [Y] ."}
31
+ {"relation": "P39", "template": " [X] billionaire elected unp\u200b\u200bCatholic [Y] ."}
32
+ {"relation": "P407", "template": " [X] scaven pronunciation.*Wikipedia speaks [Y] ."}
33
+ {"relation": "P413", "template": " [X],'' (), ex-,Liverpool [Y] ."}
34
+ {"relation": "P449", "template": " [X] premiered 1989 simulatively instinctively [Y] ."}
35
+ {"relation": "P463", "template": " [X] joins reformedolitical endangered grouping [Y] ."}
36
+ {"relation": "P47", "template": " [X] combinesfill marry territory surrounding [Y] ."}
37
+ {"relation": "P495", "template": " [X] condom announces manufacturer residence exported [Y] ."}
38
+ {"relation": "P527", "template": " [X] minus asylum cooked = compressed [Y] ."}
39
+ {"relation": "P530", "template": " [X]varOriginally kidnappedstrate neighboring [Y] ."}
40
+ {"relation": "P740", "template": " [X] prefersLondon whilst 182 favors [Y] ."}
41
+ {"relation": "P937", "template": " [X] bicycles investments railway neighborhoodAlternatively [Y] ."}
prompts/relation_extraction_bert_prompts.jsonl ADDED
@@ -0,0 +1,39 @@
1
+ {"relation": "P1001", "template": "[X] dispatched state consul federally to [Y]."}
2
+ {"relation": "P101", "template": "[X]icidalology fascinated textbook on [Y]."}
3
+ {"relation": "P103", "template": "[X] sent literary visa speaking predominantly [Y]."}
4
+ {"relation": "P106", "template": "[X] as invented firstractical aspiring [Y]."}
5
+ {"relation": "P108", "template": "[X] funded transmissions business involvement at [Y]."}
6
+ {"relation": "P127", "template": "[X] sentuti limo sponsorship to [Y]."}
7
+ {"relation": "P1303", "template": "[X] ] podcast 1935 practices unison [Y]."}
8
+ {"relation": "P131", "template": "[X] fewer congressional consul corporation bordering [Y]."}
9
+ {"relation": "P136", "template": "[X] drama bacteriatitled 80s cosmic [Y]."}
10
+ {"relation": "P138", "template": "[X] positively cited the town nicknamed [Y]."}
11
+ {"relation": "P140", "template": "[X] 2006 revelation convertedtsky practiced [Y]."}
12
+ {"relation": "P1412", "template": "[X] imported colleges translations exports speak [Y]."}
13
+ {"relation": "P159", "template": "[X]rica headquartered town across from [Y]."}
14
+ {"relation": "P17", "template": "[X] constituteronological country embassy to [Y]."}
15
+ {"relation": "P176", "template": "[X] became plays sponsor co with [Y]."}
16
+ {"relation": "P178", "template": "[X] game handed showcased separately by [Y]."}
17
+ {"relation": "P19", "template": "[X]lancheheim grew house in [Y]."}
18
+ {"relation": "P190", "template": "[X] attended waived both cities including [Y]."}
19
+ {"relation": "P20", "template": "[X]rseyjee maintained apartment in [Y]."}
20
+ {"relation": "P264", "template": "[X] became commemorated label label succeeding [Y]."}
21
+ {"relation": "P27", "template": "[X] country goals diaspora diplomat visited [Y]."}
22
+ {"relation": "P276", "template": "[X] visited crore sister town to [Y]."}
23
+ {"relation": "P279", "template": "[X]districtutical\u00e8ne word resembling [Y]."}
24
+ {"relation": "P30", "template": "[X] subfamily pardon globallyinae throughout [Y]."}
25
+ {"relation": "P31", "template": "[X] nm charitiespository nicknamed underwater [Y]."}
26
+ {"relation": "P36", "template": "[X] sued wraps owner city of [Y]."}
27
+ {"relation": "P361", "template": "[X] passwordU emblem inspired by [Y]."}
28
+ {"relation": "P364", "template": "[X] translated mistress culturally language notably [Y]."}
29
+ {"relation": "P37", "template": "[X] called countries speaking originually [Y]."}
30
+ {"relation": "P39", "template": "[X]lina \u2500 fifteenthously supreme [Y]."}
31
+ {"relation": "P407", "template": "[X] sent vocalist languages foreign especially [Y]."}
32
+ {"relation": "P413", "template": "[X] acts sentiment rookie minimum prone [Y]."}
33
+ {"relation": "P449", "template": "[X] novels channel similarly also joined [Y]."}
34
+ {"relation": "P463", "template": "[X] member testified frontman founded also [Y]."}
35
+ {"relation": "P47", "template": "[X] became consulate will include [Y]."}
36
+ {"relation": "P495", "template": "[X] shows website country abroad includes [Y]."}
37
+ {"relation": "P530", "template": "[X] send globally dedicated embassy to [Y]."}
38
+ {"relation": "P740", "template": "[X] music compliment residents resident in [Y]."}
39
+ {"relation": "P937", "template": "[X] described courthouse residency career near [Y]."}
prompts/relation_extraction_roberta_prompts.jsonl ADDED
@@ -0,0 +1,39 @@
1
+ {"relation": "P1001", "template": "[X] congratulated killers counterparts residing outage [Y] ."}
2
+ {"relation": "P101", "template": "[X]itations illustratingModern\u2010 risked [Y] ."}
3
+ {"relation": "P103", "template": "[X] website canceled learn languageposition [Y] ."}
4
+ {"relation": "P106", "template": "[X]Officersoglu internationally renown trained [Y] ."}
5
+ {"relation": "P108", "template": "[X] culinary \u00a9 fixtures file courtesy [Y] ."}
6
+ {"relation": "P127", "template": "[X] proudly celebrating playthrough ties cultured [Y] ."}
7
+ {"relation": "P1303", "template": "[X] gubernatorial 410 antique vibritone [Y] ."}
8
+ {"relation": "P131", "template": "[X] \"{\">phys town bordering [Y] ."}
9
+ {"relation": "P136", "template": "[X] poking maneuvers genre synonymous baseline [Y] ."}
10
+ {"relation": "P138", "template": "[X] slideshow painting spelling homage ()); [Y] ."}
11
+ {"relation": "P140", "template": "[X] modified kosher spiritualitycert imitate [Y] ."}
12
+ {"relation": "P1412", "template": "[X] translating pled spoken callback fluent [Y] ."}
13
+ {"relation": "P159", "template": "[X] hometown bonding hahaVisit downtown [Y] ."}
14
+ {"relation": "P17", "template": "[X] embassy factual diplomatic ambassadorooked [Y] ."}
15
+ {"relation": "P176", "template": "[X] sponsorship respectfully complimentary courtesy fuckin [Y] ."}
16
+ {"relation": "P178", "template": "[X] wikiPlanetSOURCE sponsored reckon [Y] ."}
17
+ {"relation": "P19", "template": "[X] slideshow referencing correctness hometown continent [Y] ."}
18
+ {"relation": "P190", "template": "[X] planetaking luggage transfer reaching [Y] ."}
19
+ {"relation": "P20", "template": "[X] ironically resided located recalling downtown [Y] ."}
20
+ {"relation": "P264", "template": "[X] claims primary label membershipdisc [Y] ."}
21
+ {"relation": "P27", "template": "[X] smugglers smuggled davidjl forcibly affordability [Y] ."}
22
+ {"relation": "P276", "template": "[X] photographed>:Folder cliffs overlooking [Y] ."}
23
+ {"relation": "P279", "template": "[X]enez sculpture disguised mailboxSensor [Y] ."}
24
+ {"relation": "P30", "template": "[X] tropical continent tropicalmessageoutheastern [Y] ."}
25
+ {"relation": "P31", "template": "[X] bullies campuses hypothetical substitutiononic [Y] ."}
26
+ {"relation": "P36", "template": "[X] border*.NOWVisit downtown [Y] ."}
27
+ {"relation": "P361", "template": "[X] ~/FlickrFORE blessing representing [Y] ."}
28
+ {"relation": "P364", "template": "[X] population language predomin smoker installer [Y] ."}
29
+ {"relation": "P37", "template": "[X] screamed visibly fluent descendants nutrients [Y] ."}
30
+ {"relation": "P39", "template": "[X] slideshow photo\u30aa workforce appointed [Y] ."}
31
+ {"relation": "P407", "template": "[X] screened pioneering documentaries translated curry [Y] ."}
32
+ {"relation": "P413", "template": "[X] learnedailed springTherefore veteran [Y] ."}
33
+ {"relation": "P449", "template": "[X] slideshow courtesy recommendation television broadcaster [Y] ."}
34
+ {"relation": "P463", "template": "[X] facebook referencing summarizes monikerTeam [Y] ."}
35
+ {"relation": "P47", "template": "[X] aquatic contacted diplomatic consulate imperialist [Y] ."}
36
+ {"relation": "P495", "template": "[X] webpage highlighting cultural exportsrero [Y] ."}
37
+ {"relation": "P530", "template": "[X]ECD establishes diplomatic ties fut [Y] ."}
38
+ {"relation": "P740", "template": "[X]empt adjoining merchants utilized downtown [Y] ."}
39
+ {"relation": "P937", "template": "[X]\u00df died\"\" 1931 barbecue [Y] ."}
pytest.ini ADDED
@@ -0,0 +1,5 @@
1
+ [pytest]
2
+ testpaths = tests/
3
+ pythonpath = ./
4
+ log_format = %(asctime)s - %(levelname)s - %(name)s - %(message)s
5
+ log_level = DEBUG
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ streamlit==0.79.0
2
+ tqdm==4.49.0
3
+ pandas==1.2.1
4
+ numpy==1.17.2
5
+ torch==1.4.0
6
+ transformers==2.9.1
7
+ spacy==2.2.0
8
+ termcolor==1.1.0
9
+ colorama==0.4.1
10
+ matplotlib==3.1.1
11
+ pytest
scripts/run_fact_retrieval_example.sh ADDED
@@ -0,0 +1,32 @@
1
+ #!/bin/bash
2
+ # Experiment 8
3
+ # Task: fact retrieval
4
+ # Model: RoBERTa
5
+ # Batch size: 56
6
+ # Iters: 1000
7
+ # Filtering: True
8
+
9
+ datadir=$1
10
+ logfile=$2
11
+
12
+ # Clear files
13
+ cat /dev/null > $logfile
14
+ cat /dev/null > ${logfile}.log
15
+
16
+ for path in $datadir/*; do
17
+ filename=$(basename "$path")
18
+ time CUDA_VISIBLE_DEVICES=3 python -m autoprompt.create_trigger \
19
+ --train $path/train.jsonl \
20
+ --dev $path/dev.jsonl \
21
+ --template '<s> {sub_label} [T] [T] [T] [T] [T] [P] . </s>' \
22
+ --num-cand 10 \
23
+ --accumulation-steps 1 \
24
+ --model-name roberta-large \
25
+ --bsz 56 \
26
+ --eval-size 56 \
27
+ --iters 1000 \
28
+ --label-field 'obj_label' \
29
+ --tokenize-labels \
30
+ --filter \
31
+ --print-lama >> $logfile 2>> ${logfile}.log
32
+ done
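Usage is `bash scripts/run_fact_retrieval_example.sh <datadir> <logfile>`: the loop expects `<datadir>` to contain one subdirectory per relation, each with `train.jsonl` and `dev.jsonl`, and appends one LAMA-formatted prompt per relation to `<logfile>` (stderr goes to `<logfile>.log`). The hard-coded `CUDA_VISIBLE_DEVICES=3` reflects the original experiment machine and will usually need adjusting.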
scripts/run_relation_extraction_example.sh ADDED
@@ -0,0 +1,33 @@
+ #!/bin/bash
+ # Experiment 9
+ # Task: relation extraction
+ # Model: BERT
+ # Batch size: 32
+ # Iters: 500
+ # Filtering: True
+
+ datadir=$1
+ logfile=$2
+
+ # Clear files
+ cat /dev/null > "$logfile"
+ cat /dev/null > "${logfile}.log"
+
+ for path in "$datadir"/*; do
+     filename=$(basename "$path")
+     time CUDA_VISIBLE_DEVICES=4 python -m autoprompt.create_trigger \
+         --train "$path/train.jsonl" \
+         --dev "$path/dev.jsonl" \
+         --template '[CLS] {context} [SEP] {sub_label} [T] [T] [T] [T] [T] [P] . [SEP]' \
+         --num-cand 10 \
+         --accumulation-steps 1 \
+         --model-name bert-base-cased \
+         --bsz 32 \
+         --eval-size 32 \
+         --iters 500 \
+         --label-field 'obj_label' \
+         --tokenize-labels \
+         --filter \
+         --print-lama \
+         --use-ctx >> "$logfile" 2>> "${logfile}.log"
+ done
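Compared with the fact-retrieval run above, this script conditions the search on evidence text: the template adds a `{context}` field between BERT-style `[CLS]`/`[SEP]` markers, the model switches to `bert-base-cased` with a smaller batch size (32) and fewer iterations (500), and `--use-ctx` tells `create_trigger` to make use of that context.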
setup.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ import setuptools
+ import sys
+
+
+ # Load README to get long description.
+ with open('README.md', encoding='utf-8') as f:
+     _LONG_DESCRIPTION = f.read()
+
+
+ setuptools.setup(
+     name='autoprompt',
+     version='0.0.1',
+     description='AutoPrompt',
+     long_description=_LONG_DESCRIPTION,
+     long_description_content_type='text/markdown',
+     author='UCI NLP',
+     url='https://github.com/ucinlp/autoprompt',
+     packages=setuptools.find_packages(),
+     install_requires=[],
+     extras_require={
+         'test': ['pytest']
+     },
+     classifiers=[
+         'Intended Audience :: Science/Research',
+         'Topic :: Scientific/Engineering :: Artificial Intelligence',
+     ],
+     keywords='text nlp machinelearning',
+ )
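Note that `install_requires` is deliberately empty, so `pip install -e .` only puts the `autoprompt` package on the path without pulling in dependencies; a plausible workflow (an assumption on my part, not spelled out in this commit) is to install from `requirements.txt` first and then run `pip install -e '.[test]'` to pick up `pytest` via the `test` extra.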
tests/test_create_trigger.py ADDED
@@ -0,0 +1,63 @@
+ from unittest import TestCase
+
+ import torch
+ from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer
+
+ import autoprompt.create_trigger as ct
+
+
+ def _load(model_name):
+     config = AutoConfig.from_pretrained(model_name)
+     model = AutoModelWithLMHead.from_pretrained(model_name)
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     return config, model, tokenizer
+
+
+ class TestGetEmbeddings(TestCase):
+     def test_bert(self):
+         model_name = 'bert-base-cased'
+         config, model, tokenizer = _load(model_name)
+         embeddings = ct.get_embeddings(model, config)
+         self.assertEqual(embeddings.weight.shape[0], config.vocab_size)
+
+     def test_roberta(self):
+         model_name = 'roberta-base'
+         config, model, tokenizer = _load(model_name)
+         embeddings = ct.get_embeddings(model, config)
+         self.assertEqual(embeddings.weight.shape[0], config.vocab_size)
+
+
+ class TestGradientStorage(TestCase):
+     def test_gradient_storage(self):
+         num_embeddings = 3
+         embedding_dim = 4
+         embeddings = torch.nn.Embedding(num_embeddings, embedding_dim)
+         embedding_storage = ct.GradientStorage(embeddings)
+
+         inputs = torch.tensor([0, 1, 2, 1])
+         outputs = embeddings(inputs)
+         outputs.retain_grad()
+         loss = outputs.sum()
+         loss.backward()
+
+         assert torch.equal(outputs.grad, embedding_storage.get())
+
+
+ def test_replace_trigger_tokens():
+     model_inputs = {
+         'input_ids': torch.tensor([
+             [1, 2, 3, 4],
+             [1, 1, 1, 0]
+         ])
+     }
+     trigger_ids = torch.tensor([[5, 6]])
+     trigger_mask = torch.tensor([
+         [True, True, False, False],
+         [False, True, False, True]
+     ])
+     replaced = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
+     expected = torch.tensor([
+         [5, 6, 3, 4],
+         [1, 5, 1, 6]
+     ])
+     assert torch.equal(expected, replaced['input_ids'])
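(The original `_load` hard-coded `'bert-base-cased'` and ignored its `model_name` argument, so `test_roberta` was silently testing BERT; it is fixed above to actually load the requested model.) The `TestGradientStorage` case pins down the contract that matters for gradient-guided prompt search: the stored value must equal the gradient of the loss with respect to the embedding *output*. Here is a self-contained sketch of how such a class can be built with a backward hook; this illustrates the pattern and is not necessarily `ct.GradientStorage`'s exact implementation (and on the pinned `torch==1.4.0` you would use the older `register_backward_hook` instead):

```python
import torch


class SimpleGradientStorage:
    """Capture the gradient w.r.t. a module's output via a backward hook.

    Illustrative stand-in for ct.GradientStorage (assumption: the real
    class may differ in detail).
    """

    def __init__(self, module):
        self._stored_gradient = None
        module.register_full_backward_hook(self._hook)  # torch >= 1.8

    def _hook(self, module, grad_in, grad_out):
        # grad_out[0] is d(loss)/d(module output), matching outputs.grad.
        self._stored_gradient = grad_out[0]

    def get(self):
        return self._stored_gradient


embeddings = torch.nn.Embedding(3, 4)
storage = SimpleGradientStorage(embeddings)
outputs = embeddings(torch.tensor([0, 1, 2, 1]))
outputs.retain_grad()
outputs.sum().backward()
assert torch.equal(outputs.grad, storage.get())
```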
tests/test_utils.py ADDED
@@ -0,0 +1,158 @@
+ from unittest import TestCase
+
+ import torch
+ from torch.utils.data import DataLoader
+ from transformers import AutoConfig, AutoTokenizer
+
+ import autoprompt.utils as utils
+
+
+ class TestEncodeLabel(TestCase):
+     def setUp(self):
+         self._tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
+
+     def test_single_token(self):
+         output = utils.encode_label(self._tokenizer, 'the')
+         expected_output = torch.tensor([self._tokenizer.convert_tokens_to_ids(['the'])])
+         assert torch.equal(output, expected_output)
+
+     def test_multiple_tokens(self):
+         output = utils.encode_label(self._tokenizer, ['a', 'the'])
+         expected_output = torch.tensor([
+             self._tokenizer.convert_tokens_to_ids(['a', 'the'])
+         ])
+         assert torch.equal(output, expected_output)
+
+
+ class TestTriggerTemplatizer(TestCase):
+     def setUp(self):
+         self.default_template = '[T] [T] {arbitrary} [T] {fields} [P]'
+         self.default_config = AutoConfig.from_pretrained('bert-base-cased')
+         self.default_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
+         utils.add_task_specific_tokens(self.default_tokenizer)
+         self.default_instance = {
+             'arbitrary': 'does this',
+             'fields': 'work',
+             'label': 'and'
+         }
+
+     def test_bert(self):
+         templatizer = utils.TriggerTemplatizer(
+             self.default_template,
+             self.default_config,
+             self.default_tokenizer,
+             add_special_tokens=False
+         )
+         model_inputs, label = templatizer(self.default_instance)
+
+         # Label should be mapped to its token id
+         expected_label = torch.tensor([self.default_tokenizer.convert_tokens_to_ids([self.default_instance['label']])])
+         assert torch.equal(expected_label, label)
+
+         # For BERT, output is expected to have the following keys
+         assert 'input_ids' in model_inputs
+         assert 'token_type_ids' in model_inputs
+         assert 'attention_mask' in model_inputs
+
+         # Test that the custom masks match our expectations
+         expected_trigger_mask = torch.tensor(
+             [[True, True, False, False, True, False, False]]
+         )
+         assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])
+
+         expected_predict_mask = torch.tensor(
+             [[False, False, False, False, False, False, True]]
+         )
+         assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
+
+         # Lastly, ensure [P] is replaced by a [MASK] token
+         input_ids = model_inputs['input_ids']
+         predict_mask = model_inputs['predict_mask']
+         predict_token_id = input_ids[predict_mask].squeeze().item()
+         assert predict_token_id == self.default_tokenizer.mask_token_id
+
+     def test_roberta(self):
+         config = AutoConfig.from_pretrained('roberta-base')
+         tokenizer = AutoTokenizer.from_pretrained('roberta-base')
+         utils.add_task_specific_tokens(tokenizer)
+         templatizer = utils.TriggerTemplatizer(
+             self.default_template,
+             config,
+             tokenizer,
+             add_special_tokens=False
+         )
+
+         model_inputs, label = templatizer(self.default_instance)
+
+         # Label should be mapped to its token id
+         expected_label = torch.tensor([tokenizer.convert_tokens_to_ids([self.default_instance['label']])])
+         assert torch.equal(expected_label, label)
+
+         # For RoBERTa, output is expected to have the following keys
+         assert 'input_ids' in model_inputs
+         assert 'attention_mask' in model_inputs
+
+         # Test that the custom masks match our expectations
+         expected_trigger_mask = torch.tensor(
+             [[True, True, False, False, True, False, False]]
+         )
+         assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])
+
+         expected_predict_mask = torch.tensor(
+             [[False, False, False, False, False, False, True]]
+         )
+         assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
+
+         # Lastly, ensure [P] is replaced by a [MASK] token
+         input_ids = model_inputs['input_ids']
+         predict_mask = model_inputs['predict_mask']
+         predict_token_id = input_ids[predict_mask].squeeze().item()
+         assert predict_token_id == tokenizer.mask_token_id
+
+
+ class TestCollator(TestCase):
+
+     def test_collator(self):
+         template = '[T] [T] {arbitrary} [T] {fields} [P]'
+         tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
+         config = AutoConfig.from_pretrained('bert-base-cased')
+         utils.add_task_specific_tokens(tokenizer)
+         templatizer = utils.TriggerTemplatizer(
+             template,
+             config,
+             tokenizer,
+             add_special_tokens=False
+         )
+         collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
+
+         instances = [
+             {'arbitrary': 'a', 'fields': 'the', 'label': 'hot'},
+             {'arbitrary': 'a a', 'fields': 'the the', 'label': 'cold'}
+         ]
+         templatized_instances = [templatizer(x) for x in instances]
+         loader = DataLoader(
+             templatized_instances,
+             batch_size=2,
+             shuffle=False,
+             collate_fn=collator
+         )
+         model_inputs, labels = next(iter(loader))
+
+         # Check results match our expectations
+         expected_labels = torch.tensor([
+             tokenizer.encode('hot', add_special_tokens=False, add_prefix_space=True),
+             tokenizer.encode('cold', add_special_tokens=False, add_prefix_space=True),
+         ])
+         assert torch.equal(expected_labels, labels)
+
+         expected_trigger_mask = torch.tensor([
+             [True, True, False, True, False, False, False, False],
+             [True, True, False, False, True, False, False, False],
+         ])
+         assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])
+
+         expected_predict_mask = torch.tensor([
+             [False, False, False, False, False, True, False, False],
+             [False, False, False, False, False, False, False, True],
+         ])
+         assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
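(Two small cleanups above: the "ouput"/"For BERT" comment typos in `test_roberta` are corrected, and a leftover debug `print(model_inputs)` is dropped, so the hunk is now 158 lines.) The collator test is essentially checking dynamic padding: the two templatized instances tokenize to different lengths (6 vs. 8 positions once the triggers and predict slot are counted), so the shorter one's ids must be right-padded with `pad_token_id` and its boolean masks with `False` before stacking. A minimal sketch of that padding step, assuming nothing about `utils.Collator`'s actual internals:

```python
import torch


def pad_to_batch(tensors, pad_value):
    """Right-pad 1-D tensors to a common length and stack them."""
    max_len = max(t.size(-1) for t in tensors)
    out = torch.full((len(tensors), max_len), pad_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        out[i, :t.size(-1)] = t
    return out


# Ids get padded with the tokenizer's pad id; boolean masks with False.
ids = [torch.tensor([5, 6, 7]), torch.tensor([5, 6, 7, 8])]
print(pad_to_batch(ids, pad_value=0))
# tensor([[5, 6, 7, 0],
#         [5, 6, 7, 8]])
```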