Spaces:
Build error
Build error
Commit
·
861c889
unverified
·
0
Parent(s):
Initial commit
Browse files- .circleci/config.yml +18 -0
- .gitignore +66 -0
- README.md +168 -0
- app.py +580 -0
- app/.streamlit/config.toml +10 -0
- assets/icon.png +0 -0
- assets/sst2_train.jsonl +32 -0
- autoprompt/__init__.py +0 -0
- autoprompt/create_trigger.py +523 -0
- autoprompt/finetune.py +203 -0
- autoprompt/label_search.py +162 -0
- autoprompt/popsicle.py +134 -0
- autoprompt/run_linear_probe.py +151 -0
- autoprompt/utils.py +376 -0
- prompts/fact_retrieval_bert_prompts.jsonl +41 -0
- prompts/fact_retrieval_roberta_prompts.jsonl +41 -0
- prompts/relation_extraction_bert_prompts.jsonl +39 -0
- prompts/relation_extraction_roberta_prompts.jsonl +39 -0
- pytest.ini +5 -0
- requirements.txt +11 -0
- scripts/run_fact_retrieval_example.sh +32 -0
- scripts/run_relation_extraction_example.sh +33 -0
- setup.py +29 -0
- tests/test_create_trigger.py +63 -0
- tests/test_utils.py +159 -0
.circleci/config.yml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: 2
|
2 |
+
jobs:
|
3 |
+
build:
|
4 |
+
working_directory: ~/autoprompt
|
5 |
+
docker:
|
6 |
+
- image: circleci/python:3.7
|
7 |
+
environment:
|
8 |
+
OMP_NUM_THREADS: 1
|
9 |
+
resource_class: medium
|
10 |
+
parallelism: 1
|
11 |
+
steps:
|
12 |
+
- checkout
|
13 |
+
- run: pip install --upgrade pip
|
14 |
+
- run: pip install -r requirements.txt
|
15 |
+
- run: python -m pytest --disable-warnings
|
16 |
+
- store_test_results:
|
17 |
+
path: test-results
|
18 |
+
|
.gitignore
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
|
5 |
+
# C extensions
|
6 |
+
*.so
|
7 |
+
|
8 |
+
# Distribution / packaging
|
9 |
+
.Python
|
10 |
+
env/
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
*.egg-info/
|
23 |
+
.installed.cfg
|
24 |
+
*.egg
|
25 |
+
|
26 |
+
# PyInstaller
|
27 |
+
# Usually these files are written by a python script from a template
|
28 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
29 |
+
*.manifest
|
30 |
+
*.spec
|
31 |
+
|
32 |
+
# Installer logs
|
33 |
+
pip-log.txt
|
34 |
+
pip-delete-this-directory.txt
|
35 |
+
|
36 |
+
# Unit test / coverage reports
|
37 |
+
htmlcov/
|
38 |
+
.tox/
|
39 |
+
.coverage
|
40 |
+
.coverage.*
|
41 |
+
.cache
|
42 |
+
nosetests.xml
|
43 |
+
coverage.xml
|
44 |
+
*,cover
|
45 |
+
|
46 |
+
# Translations
|
47 |
+
*.mo
|
48 |
+
*.pot
|
49 |
+
|
50 |
+
# Django stuff:
|
51 |
+
*.log
|
52 |
+
|
53 |
+
# Sphinx documentation
|
54 |
+
docs/_build/
|
55 |
+
|
56 |
+
# PyBuilder
|
57 |
+
target/
|
58 |
+
|
59 |
+
# IPython checkpoints
|
60 |
+
.ipynb_checkpoints
|
61 |
+
|
62 |
+
# Miscellaneous
|
63 |
+
.DS_Store
|
64 |
+
.vscode/
|
65 |
+
out/
|
66 |
+
#data/
|
README.md
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AutoPrompt
|
2 |
+
An automated method based on gradient-guided search to create prompts for a diverse set of NLP tasks. AutoPrompt demonstrates that masked language models (MLMs) have an innate ability to perform sentiment analysis, natural language inference, fact retrieval, and relation extraction. Check out our [website](https://ucinlp.github.io/autoprompt/) for the paper and more information.
|
3 |
+
|
4 |
+
## Table of Contents
|
5 |
+
* [Setup](#setup)
|
6 |
+
* [Generating Prompts](#generating-prompts)
|
7 |
+
* [Label Token Selection](#label-token-selection)
|
8 |
+
* [Evaluation for Fact Retrieval and Relation Extraction](#evaluation-for-fact-retrieval-and-relation-extraction)
|
9 |
+
* [Citation](#citation)
|
10 |
+
|
11 |
+
## Setup
|
12 |
+
|
13 |
+
### 1. Create conda environment
|
14 |
+
```
|
15 |
+
conda create -n autoprompt -y python=3.7 && conda activate autoprompt
|
16 |
+
```
|
17 |
+
|
18 |
+
### 2. Install dependecies
|
19 |
+
Install the required packages
|
20 |
+
```
|
21 |
+
pip install -r requirements.txt
|
22 |
+
```
|
23 |
+
Also download the spacy model
|
24 |
+
```
|
25 |
+
python -m spacy download en
|
26 |
+
```
|
27 |
+
|
28 |
+
### 3. Download the data
|
29 |
+
The datasets for sentiment analysis, NLI, fact retrieval, and relation extraction are available to download [here](https://drive.google.com/drive/folders/1vVhgnSXmbuJb6GLPn_FErY1xDTh1xyv-?usp=sharing)
|
30 |
+
|
31 |
+
There are a couple different datasets for fact retrieval and relation extraction so here are brief overviews of each:
|
32 |
+
- Fact Retrieval
|
33 |
+
- `original`: We used the T-REx subset provided by LAMA as our test set and gathered more facts from the [original T-REx dataset](https://hadyelsahar.github.io/t-rex/) that we partitioned into train and dev sets
|
34 |
+
- `original_rob`: We filtered facts in `original` so that each object is a single token for both BERT and RoBERTa
|
35 |
+
- `trex`: We split the extra T-REx data collected (for train/val sets of `original`) into train, dev, test sets
|
36 |
+
- Relation Extraction
|
37 |
+
- Trimmed the `original` dataset to compensate for both the [RE baseline](https://github.com/UKPLab/emnlp2017-relation-extraction) and RoBERTa. We also excluded relations `P527` and `P1376` because the RE baseline doesn’t consider them.
|
38 |
+
|
39 |
+
## Generating Prompts
|
40 |
+
|
41 |
+
### Quick Overview of Templates
|
42 |
+
A prompt is constructed by mapping things like the original input and trigger tokens to a template that looks something like
|
43 |
+
|
44 |
+
`[CLS] {sub_label} [T] [T] [T] [P]. [SEP]`
|
45 |
+
|
46 |
+
The example above is a template for generating fact retrieval prompts with 3 trigger tokens where `{sub_label}` is a placeholder for the subject in any (subject, relation, object) triplet in fact retrieval. `[P]` denotes the placement of a special `[MASK]` token that will be used to "fill-in-the-blank" by the language model. Each trigger token in the set of trigger tokens that are shared across all prompts is denoted by `[T]`.
|
47 |
+
|
48 |
+
Depending on the language model (i.e. BERT or RoBERTa) you choose to generate prompts, the special tokens will be different. For BERT, stick `[CLS]` and `[SEP]` to each end of the template. For RoBERTa, use `<s>` and `</s>` instead.
|
49 |
+
|
50 |
+
### Sentiment Analysis
|
51 |
+
```
|
52 |
+
python -m autoprompt.create_trigger \
|
53 |
+
--train glue_data/SST-2/train.tsv \
|
54 |
+
--dev glue_data/SST-2/dev.tsv \
|
55 |
+
--template '<s> {sentence} [T] [T] [T] [P] . </s>' \
|
56 |
+
--label-map '{"0": ["Ġworse", "Ġincompetence", "ĠWorse", "Ġblamed", "Ġsucked"], "1": ["ĠCris", "Ġmarvelous", "Ġphilanthrop", "Ġvisionary", "Ġwonderful"]}' \
|
57 |
+
--num-cand 100 \
|
58 |
+
--accumulation-steps 30 \
|
59 |
+
--bsz 24 \
|
60 |
+
--eval-size 48 \
|
61 |
+
--iters 180 \
|
62 |
+
--model-name roberta-large
|
63 |
+
```
|
64 |
+
|
65 |
+
### Natural Language Inference
|
66 |
+
```
|
67 |
+
python -m autoprompt.create_trigger --train SICK_TRAIN_ALL_S.tsv --dev SICK_DEV_ALL_S.tsv --template '<s> {sentence_A} [P] [T] [T] [T] [T] {sentence_B} </s>' --label-map '{"ENTAILMENT": ["\u0120Taiwan", "\u0120Ara", "abet"], "CONTRADICTION": ["\u0120Only", "\u0120Didn", "\u0120BUT"], "NEUTRAL": ["icy", "oder", "agna"]}' --bsz 120 --model-name roberta-large
|
68 |
+
```
|
69 |
+
|
70 |
+
### Fact Retrieval
|
71 |
+
```
|
72 |
+
python -m autoprompt.create_trigger \
|
73 |
+
--train $path/train.jsonl \
|
74 |
+
--dev $path/dev.jsonl \
|
75 |
+
--template '<s> {sub_label} [T] [T] [T] [P] . </s>' \
|
76 |
+
--num-cand 10 \
|
77 |
+
--accumulation-steps 1 \
|
78 |
+
--model-name roberta-large \
|
79 |
+
--bsz 56 \
|
80 |
+
--eval-size 56 \
|
81 |
+
--iters 1000 \
|
82 |
+
--label-field 'obj_label' \
|
83 |
+
--tokenize-labels \
|
84 |
+
--filter \
|
85 |
+
--print-lama
|
86 |
+
```
|
87 |
+
|
88 |
+
### Relation Extraction
|
89 |
+
```
|
90 |
+
python -m autoprompt.create_trigger \
|
91 |
+
--train $path/train.jsonl \
|
92 |
+
--dev $path/dev.jsonl \
|
93 |
+
--template '[CLS] {context} [SEP] {sub_label} [T] [T] [T] [P] . [SEP]' \
|
94 |
+
--num-cand 10 \
|
95 |
+
--accumulation-steps 1 \
|
96 |
+
--model-name bert-base-cased \
|
97 |
+
--bsz 32 \
|
98 |
+
--eval-size 32 \
|
99 |
+
--iters 500 \
|
100 |
+
--label-field 'obj_label' \
|
101 |
+
--tokenize-labels \
|
102 |
+
--filter \
|
103 |
+
--print-lama \
|
104 |
+
--use-ctx
|
105 |
+
```
|
106 |
+
|
107 |
+
## Label Token Selection
|
108 |
+
|
109 |
+
For sentiment analysis
|
110 |
+
```
|
111 |
+
python -m autoprompt.label_search --train ../data/SST-2/train.tsv --template '[CLS] {sentence} [T] [T] [T] [P]. [SEP]' --label-map '{"0": 0, "1": 1}' --iters 50 --model-name 'bert-base-cased'
|
112 |
+
```
|
113 |
+
|
114 |
+
For NLI
|
115 |
+
```
|
116 |
+
python -m autoprompt.label_search --train ../data/SICK-E-balanced/3-balance/SICK_TRAIN_ALL_S.tsv --template '[CLS] {sentence} [T] [T] [T] [P]. [SEP]' --label-map '{"entailment": 0, "contradiction": 1, "neutral": 2}' --iters 50 --model-name 'bert-base-cased'
|
117 |
+
```
|
118 |
+
|
119 |
+
## Evaluation for Fact Retrieval and Relation Extraction
|
120 |
+
|
121 |
+
### 1. Setup LAMA
|
122 |
+
Clone [our fork](https://github.com/taylorshin/LAMA) of the LAMA repo and follow the directions to set it up outside of the AutoPrompt repo.
|
123 |
+
We recommended creating a separate conda environment for LAMA due to different dependencies and requirements.
|
124 |
+
|
125 |
+
Copy the AutoPrompt data folder into the `data` directory of LAMA or set `data_path_pre` in `scripts/run_experiments.py` to a custom data location.
|
126 |
+
|
127 |
+
In order to get LAMA to work with RoBERTa, run the following commands:
|
128 |
+
```
|
129 |
+
mkdir pre-trained_language_models/roberta
|
130 |
+
cd pre-trained_language_models/roberta
|
131 |
+
curl -O https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
|
132 |
+
tar -xvzf roberta.large.tar.gz
|
133 |
+
```
|
134 |
+
|
135 |
+
### 2. Update prompts
|
136 |
+
Update the `data/relations.jsonl` file with your own automatically generated prompts
|
137 |
+
|
138 |
+
### 3. Configure settings
|
139 |
+
To change evaluation settings, go to `scripts/run_experiments.py` and update the configurable values accordingly.
|
140 |
+
Note: each of the configurable settings are marked with a `[CONFIGURABLE]` comment.
|
141 |
+
|
142 |
+
- Uncomment the settings of the LM you want to evaluate with (and comment out the other LM settings) in the `LMs` list at the top of the file
|
143 |
+
- Update the `common_vocab_filename` field to the appropriate file. Anything evaluating both BERT and RoBERTa requires this field to be `common_vocab_cased_rob.txt` instead of the usual `common_vocab_cased.txt`.
|
144 |
+
- Set `use_ctx` to `True` if running evaluation for Relation Extraction
|
145 |
+
- Set `synthetic` to `True` for perturbed sentence evaluation for Relation Extraction
|
146 |
+
- In `get_TREx_parameters` function, set `data_path_pre` to the corresponding data path (e.g. `"../data/relation_extraction"` for Relation Extraction)
|
147 |
+
|
148 |
+
### 4. Evaluate prompts
|
149 |
+
Run the evaluation code
|
150 |
+
```
|
151 |
+
python scripts/run_experiments.py
|
152 |
+
```
|
153 |
+
|
154 |
+
### 4. Miscellaneous
|
155 |
+
Set `PYTHONPATH` if the following error occurs: `ModuleNotFoundError: No module named 'lama'`
|
156 |
+
```
|
157 |
+
export PYTHONPATH="${PYTHONPATH}:/path/to/the/AutoPrompt/repo"
|
158 |
+
```
|
159 |
+
|
160 |
+
## Citation
|
161 |
+
```
|
162 |
+
@inproceedings{autoprompt:emnlp20,
|
163 |
+
author = {Taylor Shin and Yasaman Razeghi and Robert L. Logan IV and Eric Wallace and Sameer Singh},
|
164 |
+
title = { {AutoPrompt}: Eliciting Knowledge from Language Models with Automatically Generated Prompts },
|
165 |
+
booktitle = {Empirical Methods in Natural Language Processing (EMNLP)},
|
166 |
+
year = {2020}
|
167 |
+
}
|
168 |
+
```
|
app.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
from dataclasses import dataclass
|
3 |
+
import io
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import random
|
7 |
+
import sys
|
8 |
+
from typing import Dict, List
|
9 |
+
|
10 |
+
import pandas as pd
|
11 |
+
import streamlit as st
|
12 |
+
import torch
|
13 |
+
import transformers
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from autoprompt import utils
|
17 |
+
import autoprompt.create_trigger as ct
|
18 |
+
|
19 |
+
|
20 |
+
# logging.getLogger("streamlit.caching").addHandler(logging.StreamHandler(sys.stdout))
|
21 |
+
# logging.getLogger("streamlit.caching").setLevel(logging.DEBUG)
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
with open('assets/sst2_train.jsonl', 'r') as f:
|
27 |
+
DEFAULT_TRAIN = [json.loads(line) for line in f]
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class CacheTest:
|
32 |
+
"""
|
33 |
+
Stores whether the train button has been pressed for a given
|
34 |
+
set of inputs to run_autoprompt.
|
35 |
+
"""
|
36 |
+
is_test: bool
|
37 |
+
|
38 |
+
|
39 |
+
class CacheMiss(Exception):
|
40 |
+
pass
|
41 |
+
|
42 |
+
|
43 |
+
def css_hack():
|
44 |
+
"""
|
45 |
+
Inject some style into this app. ヽ(⌐■_■)ノ
|
46 |
+
"""
|
47 |
+
st.markdown(
|
48 |
+
"""
|
49 |
+
<style>
|
50 |
+
code {
|
51 |
+
color: #eec66d;
|
52 |
+
}
|
53 |
+
.css-gtmd9c a {
|
54 |
+
color: #6f98af;
|
55 |
+
}
|
56 |
+
</style>
|
57 |
+
""",
|
58 |
+
unsafe_allow_html=True
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
# Setting eq and frozen ensures that a __hash__ method is generated which is needed for caching to
|
63 |
+
# properly respond to changed args.
|
64 |
+
@dataclass(eq=True, frozen=True)
|
65 |
+
class Args:
|
66 |
+
# Configurable
|
67 |
+
template: str
|
68 |
+
model_name: str
|
69 |
+
iters: int
|
70 |
+
num_cand: int
|
71 |
+
accumulation_steps: int
|
72 |
+
|
73 |
+
# Non-Configurable
|
74 |
+
seed = 0
|
75 |
+
sentence_size = 64
|
76 |
+
tokenize_labels = True
|
77 |
+
filter = False
|
78 |
+
initial_trigger = None
|
79 |
+
label_field = "label"
|
80 |
+
bsz = 32
|
81 |
+
eval_size = 1
|
82 |
+
|
83 |
+
@classmethod
|
84 |
+
def from_streamlit(cls):
|
85 |
+
st.sidebar.image('assets/icon.png', width=150)
|
86 |
+
st.sidebar.markdown('### Training Parameters')
|
87 |
+
model_name = st.sidebar.selectbox(
|
88 |
+
"Model",
|
89 |
+
options=['roberta-large', 'bert-base-cased'],
|
90 |
+
help="Language model used for training and evaluation."
|
91 |
+
)
|
92 |
+
iters = int(st.sidebar.number_input(
|
93 |
+
"Iterations",
|
94 |
+
value=10,
|
95 |
+
min_value=1,
|
96 |
+
max_value=100,
|
97 |
+
help="Number of trigger search iterations. Larger values may yield better results."
|
98 |
+
))
|
99 |
+
num_cand = int(st.sidebar.number_input(
|
100 |
+
"Number of Candidates",
|
101 |
+
value=25,
|
102 |
+
min_value=1,
|
103 |
+
max_value=100,
|
104 |
+
help="Number of candidate trigger token replacements to evaluate during each search "
|
105 |
+
"iteration. Larger values may yield better results."
|
106 |
+
))
|
107 |
+
accumulation_steps = int(st.sidebar.number_input(
|
108 |
+
"Gradient Accumulation Steps",
|
109 |
+
value=1,
|
110 |
+
min_value=1,
|
111 |
+
max_value=10,
|
112 |
+
help="Number of gradient accumulation steps used during training. Larger values may yield "
|
113 |
+
"better results. Cannot be larger than half the dataset size."
|
114 |
+
))
|
115 |
+
st.sidebar.markdown(
|
116 |
+
"""
|
117 |
+
### Template
|
118 |
+
|
119 |
+
Templates define how task-specific inputs are combined with trigger tokens to create
|
120 |
+
the prompt. They should contain the following placeholders:
|
121 |
+
- `{sentence}`: Placeholders for the task-specific input fields contain the field name
|
122 |
+
between curly brackets. For manually entered data the field name is `{sentence}`. For
|
123 |
+
uploaded csv's, field names should correspond to columns in the csv.
|
124 |
+
- `[T]`: Placeholder for a trigger token. These are learned from the training data.
|
125 |
+
- `[P]`: Placeholder for where to insert the [MASK] token that the model will predict
|
126 |
+
on.
|
127 |
+
|
128 |
+
Templates can also include manually written text (such as the
|
129 |
+
period in the default example below).
|
130 |
+
"""
|
131 |
+
)
|
132 |
+
template = st.sidebar.text_input("Template", "{sentence} [T] [T] [T] [P].")
|
133 |
+
return cls(
|
134 |
+
template=template,
|
135 |
+
model_name=model_name,
|
136 |
+
iters=iters,
|
137 |
+
num_cand=num_cand,
|
138 |
+
accumulation_steps=accumulation_steps,
|
139 |
+
)
|
140 |
+
|
141 |
+
|
142 |
+
# TODO(rloganiv): This probably could use a better name...
|
143 |
+
@dataclass
|
144 |
+
class GlobalData:
|
145 |
+
device: torch.device
|
146 |
+
config: transformers.PretrainedConfig
|
147 |
+
model: transformers.PreTrainedModel
|
148 |
+
tokenizer: transformers.PreTrainedTokenizer
|
149 |
+
embeddings: torch.nn.Module
|
150 |
+
embedding_gradient: ct.GradientStorage
|
151 |
+
predictor: ct.PredictWrapper
|
152 |
+
|
153 |
+
@classmethod
|
154 |
+
@st.cache(allow_output_mutation=True)
|
155 |
+
def from_pretrained(cls, model_name):
|
156 |
+
logger.info(f'Loading pretrained model: {model_name}')
|
157 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
158 |
+
if torch.cuda.is_available():
|
159 |
+
st.write('CUDA is available')
|
160 |
+
else:
|
161 |
+
st.write('CUDA not available')
|
162 |
+
config, model, tokenizer = ct.load_pretrained(model_name)
|
163 |
+
model.to(device)
|
164 |
+
embeddings = ct.get_embeddings(model, config)
|
165 |
+
embedding_gradient = ct.GradientStorage(embeddings)
|
166 |
+
predictor = ct.PredictWrapper(model)
|
167 |
+
return cls(
|
168 |
+
device,
|
169 |
+
config,
|
170 |
+
model,
|
171 |
+
tokenizer,
|
172 |
+
embeddings,
|
173 |
+
embedding_gradient,
|
174 |
+
predictor
|
175 |
+
)
|
176 |
+
|
177 |
+
|
178 |
+
@dataclass
|
179 |
+
class Dataset:
|
180 |
+
train: List[int]
|
181 |
+
label_map: Dict[str, str]
|
182 |
+
|
183 |
+
|
184 |
+
def load_trigger_dataset(dataset, templatizer):
|
185 |
+
instances = []
|
186 |
+
for x in dataset:
|
187 |
+
instances.append(templatizer(x))
|
188 |
+
return instances
|
189 |
+
|
190 |
+
|
191 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True, hash_funcs={CacheTest: lambda o: 0})
|
192 |
+
def run_autoprompt(args, dataset, cache_test):
|
193 |
+
if cache_test.is_test:
|
194 |
+
raise CacheMiss()
|
195 |
+
|
196 |
+
ct.set_seed(args.seed)
|
197 |
+
global_data = GlobalData.from_pretrained(args.model_name)
|
198 |
+
|
199 |
+
templatizer = utils.TriggerTemplatizer(
|
200 |
+
args.template,
|
201 |
+
global_data.config,
|
202 |
+
global_data.tokenizer,
|
203 |
+
label_field=args.label_field,
|
204 |
+
label_map=dataset.label_map,
|
205 |
+
tokenize_labels=args.tokenize_labels,
|
206 |
+
add_special_tokens=True,
|
207 |
+
)
|
208 |
+
evaluation_fn = ct.AccuracyFn(global_data.tokenizer, dataset.label_map, global_data.device,
|
209 |
+
tokenize_labels=args.tokenize_labels)
|
210 |
+
|
211 |
+
# Do not allow for initial trigger specification.
|
212 |
+
trigger_ids = [global_data.tokenizer.mask_token_id] * templatizer.num_trigger_tokens
|
213 |
+
trigger_ids = torch.tensor(trigger_ids, device=global_data.device).unsqueeze(0)
|
214 |
+
best_trigger_ids = trigger_ids.clone()
|
215 |
+
|
216 |
+
# Load datasets
|
217 |
+
logger.info('Loading datasets')
|
218 |
+
collator = utils.Collator(pad_token_id=global_data.tokenizer.pad_token_id)
|
219 |
+
try:
|
220 |
+
train_dataset = load_trigger_dataset(dataset.train, templatizer)
|
221 |
+
except KeyError as e:
|
222 |
+
raise RuntimeError(
|
223 |
+
'A field in your template is not present in the uploaded dataset. '
|
224 |
+
f'Check that there is a column with the name: {e}'
|
225 |
+
)
|
226 |
+
|
227 |
+
train_loader = torch.utils.data.DataLoader(
|
228 |
+
train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
|
229 |
+
|
230 |
+
progress = st.progress(0.0)
|
231 |
+
trigger_placeholder = st.empty()
|
232 |
+
best_dev_metric = -float('inf')
|
233 |
+
for i in range(args.iters):
|
234 |
+
logger.info(f'Iteration: {i}')
|
235 |
+
progress.progress(float(i)/args.iters)
|
236 |
+
|
237 |
+
current_trigger = ','.join(global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0)))
|
238 |
+
trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')
|
239 |
+
|
240 |
+
global_data.model.zero_grad()
|
241 |
+
train_iter = iter(train_loader)
|
242 |
+
averaged_grad = None
|
243 |
+
|
244 |
+
# Compute gradient of loss
|
245 |
+
for step in range(args.accumulation_steps):
|
246 |
+
try:
|
247 |
+
model_inputs, labels = next(train_iter)
|
248 |
+
except:
|
249 |
+
logger.warning(
|
250 |
+
'Insufficient data for number of accumulation steps. '
|
251 |
+
'Effective batch size will be smaller than specified.'
|
252 |
+
)
|
253 |
+
break
|
254 |
+
model_inputs = {k: v.to(global_data.device) for k, v in model_inputs.items()}
|
255 |
+
labels = labels.to(global_data.device)
|
256 |
+
predict_logits = global_data.predictor(model_inputs, trigger_ids)
|
257 |
+
loss = ct.get_loss(predict_logits, labels).mean()
|
258 |
+
loss.backward()
|
259 |
+
|
260 |
+
grad = global_data.embedding_gradient.get()
|
261 |
+
bsz, _, emb_dim = grad.size()
|
262 |
+
selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
|
263 |
+
grad = torch.masked_select(grad, selection_mask)
|
264 |
+
grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)
|
265 |
+
|
266 |
+
if averaged_grad is None:
|
267 |
+
averaged_grad = grad.sum(dim=0) / args.accumulation_steps
|
268 |
+
else:
|
269 |
+
averaged_grad += grad.sum(dim=0) / args.accumulation_steps
|
270 |
+
|
271 |
+
logger.info('Evaluating Candidates')
|
272 |
+
pbar = tqdm(range(args.accumulation_steps))
|
273 |
+
train_iter = iter(train_loader)
|
274 |
+
|
275 |
+
token_to_flip = i % templatizer.num_trigger_tokens
|
276 |
+
candidates = ct.hotflip_attack(averaged_grad[token_to_flip],
|
277 |
+
global_data.embeddings.weight,
|
278 |
+
increase_loss=False,
|
279 |
+
num_candidates=args.num_cand)
|
280 |
+
current_score = 0
|
281 |
+
candidate_scores = torch.zeros(args.num_cand, device=global_data.device)
|
282 |
+
denom = 0
|
283 |
+
for step in pbar:
|
284 |
+
try:
|
285 |
+
model_inputs, labels = next(train_iter)
|
286 |
+
except:
|
287 |
+
logger.warning(
|
288 |
+
'Insufficient data for number of accumulation steps. '
|
289 |
+
'Effective batch size will be smaller than specified.'
|
290 |
+
)
|
291 |
+
break
|
292 |
+
model_inputs = {k: v.to(global_data.device) for k, v in model_inputs.items()}
|
293 |
+
labels = labels.to(global_data.device)
|
294 |
+
with torch.no_grad():
|
295 |
+
predict_logits = global_data.predictor(model_inputs, trigger_ids)
|
296 |
+
eval_metric = evaluation_fn(predict_logits, labels)
|
297 |
+
|
298 |
+
# Update current score
|
299 |
+
current_score += eval_metric.sum()
|
300 |
+
denom += labels.size(0)
|
301 |
+
|
302 |
+
# NOTE: Instead of iterating over tokens to flip we randomly change just one each
|
303 |
+
# time so the gradients don't get stale.
|
304 |
+
for i, candidate in enumerate(candidates):
|
305 |
+
|
306 |
+
# if candidate.item() in filter_candidates:
|
307 |
+
# candidate_scores[i] = -1e32
|
308 |
+
# continue
|
309 |
+
|
310 |
+
temp_trigger = trigger_ids.clone()
|
311 |
+
temp_trigger[:, token_to_flip] = candidate
|
312 |
+
with torch.no_grad():
|
313 |
+
predict_logits = global_data.predictor(model_inputs, temp_trigger)
|
314 |
+
eval_metric = evaluation_fn(predict_logits, labels)
|
315 |
+
|
316 |
+
candidate_scores[i] += eval_metric.sum()
|
317 |
+
|
318 |
+
if (candidate_scores >= current_score).any():
|
319 |
+
logger.info('Better trigger detected.')
|
320 |
+
best_candidate_score = candidate_scores.max()
|
321 |
+
best_candidate_idx = candidate_scores.argmax()
|
322 |
+
trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
|
323 |
+
logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')
|
324 |
+
|
325 |
+
# Skip eval
|
326 |
+
best_trigger_ids = trigger_ids.clone()
|
327 |
+
|
328 |
+
progress.progress(1.0)
|
329 |
+
current_trigger = ','.join(global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0)))
|
330 |
+
trigger_placeholder.markdown(f'**Current trigger**: {current_trigger}')
|
331 |
+
|
332 |
+
best_trigger_tokens = global_data.tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))
|
333 |
+
|
334 |
+
train_output = predict_test(map(lambda x: x['sentence'], dataset.train), dataset.label_map,
|
335 |
+
templatizer, best_trigger_ids, global_data.tokenizer, global_data.predictor, args)
|
336 |
+
|
337 |
+
# Streamlit does not like accessing widgets across functions, which is
|
338 |
+
# problematic for this "live updating" widget which we want to still
|
339 |
+
# display even if the train output is cached. To get around this, we're
|
340 |
+
# going to delete the widget and replace it with a very similar looking
|
341 |
+
# widget outside the function...no one will ever notice ;)
|
342 |
+
trigger_placeholder.empty()
|
343 |
+
|
344 |
+
return (
|
345 |
+
best_trigger_tokens,
|
346 |
+
current_score/denom,
|
347 |
+
dataset.label_map,
|
348 |
+
templatizer,
|
349 |
+
best_trigger_ids,
|
350 |
+
global_data.tokenizer,
|
351 |
+
global_data.predictor,
|
352 |
+
args,
|
353 |
+
train_output
|
354 |
+
)
|
355 |
+
|
356 |
+
|
357 |
+
def predict_test(sentences, label_map, templatizer, best_trigger_ids, tokenizer, predictor, args):
|
358 |
+
# Evaluate clean
|
359 |
+
output = { 'sentences': [] }
|
360 |
+
any_label = None
|
361 |
+
for label in label_map.values():
|
362 |
+
output[label] = []
|
363 |
+
any_label = label
|
364 |
+
output['prompt'] = []
|
365 |
+
for sentence in sentences:
|
366 |
+
model_inputs, _ = templatizer({'sentence': sentence, 'label': any_label})
|
367 |
+
model_inputs = {k: v.to(best_trigger_ids.device) for k, v in model_inputs.items()}
|
368 |
+
|
369 |
+
prompt_ids = ct.replace_trigger_tokens(
|
370 |
+
model_inputs, best_trigger_ids, model_inputs['trigger_mask'])
|
371 |
+
|
372 |
+
prompt = ' '.join(tokenizer.convert_ids_to_tokens(prompt_ids['input_ids'][0]))
|
373 |
+
output['prompt'].append(prompt)
|
374 |
+
|
375 |
+
predict_logits = predictor(model_inputs, best_trigger_ids)
|
376 |
+
output['sentences'].append(sentence)
|
377 |
+
for label in label_map.values():
|
378 |
+
label_id = utils.encode_label(tokenizer=tokenizer, label=label, tokenize=args.tokenize_labels)
|
379 |
+
label_id = label_id.to(best_trigger_ids.device)
|
380 |
+
label_loss = ct.get_loss(predict_logits, label_id)
|
381 |
+
# st.write(sentence, label, label_loss)
|
382 |
+
output[label].append(label_loss.item())
|
383 |
+
return output
|
384 |
+
|
385 |
+
|
386 |
+
def manual_dataset(use_defaults):
|
387 |
+
|
388 |
+
num_train_instances = st.slider("Number of Train Instances", 4, 32, 8)
|
389 |
+
any_empty = False
|
390 |
+
dataset = []
|
391 |
+
data_col, label_col = st.beta_columns([3,1])
|
392 |
+
for i in range(num_train_instances):
|
393 |
+
default_data = DEFAULT_TRAIN[i]['sentence'] if use_defaults else ''
|
394 |
+
default_label = DEFAULT_TRAIN[i]['label'] if use_defaults else ''
|
395 |
+
with data_col:
|
396 |
+
data = st.text_input("Train Instance " + str(i+1), default_data)
|
397 |
+
with label_col:
|
398 |
+
label = st.text_input("Train Label " + str(i+1), default_label, max_chars=20)
|
399 |
+
if data == "" or label == "":
|
400 |
+
any_empty = True
|
401 |
+
dataset.append({'sentence': data, 'label': label})
|
402 |
+
|
403 |
+
label_set = list(set(map(lambda x: x['label'], dataset)))
|
404 |
+
label_idx = {x: i for i, x in enumerate(label_set)}
|
405 |
+
label_map = dict(map(lambda x: (x, x), label_set))
|
406 |
+
|
407 |
+
if any_empty:
|
408 |
+
st.warning('Waiting for data to be added')
|
409 |
+
st.stop()
|
410 |
+
|
411 |
+
if len(label_set) < 2:
|
412 |
+
st.warning('Not enough labels')
|
413 |
+
st.stop()
|
414 |
+
|
415 |
+
return Dataset(
|
416 |
+
train=dataset,
|
417 |
+
label_map=label_map
|
418 |
+
)
|
419 |
+
|
420 |
+
|
421 |
+
def csv_dataset():
|
422 |
+
st.markdown("""
|
423 |
+
Please upload your training and evaluation csv files.
|
424 |
+
|
425 |
+
Format restrictions:
|
426 |
+
- The file is required to have a header
|
427 |
+
- The column name of the output field should be `label`.
|
428 |
+
- Each file should contain no more than 64 rows.
|
429 |
+
""")
|
430 |
+
train_csv = st.file_uploader('Train', accept_multiple_files=False)
|
431 |
+
|
432 |
+
if train_csv is None:
|
433 |
+
st.stop()
|
434 |
+
|
435 |
+
with io.StringIO(train_csv.getvalue().decode('utf-8')) as f:
|
436 |
+
reader = csv.DictReader(f)
|
437 |
+
train_dataset = list(reader)
|
438 |
+
if len(train_dataset) > 64:
|
439 |
+
raise ValueError('Train dataset is too large. Please limit the number '
|
440 |
+
'of examples to 64 or less.')
|
441 |
+
|
442 |
+
labels = set(x['label'] for x in train_dataset)
|
443 |
+
label_map = {x: x for x in labels}
|
444 |
+
|
445 |
+
return Dataset(
|
446 |
+
train=train_dataset,
|
447 |
+
label_map=label_map
|
448 |
+
)
|
449 |
+
|
450 |
+
|
451 |
+
def run():
|
452 |
+
css_hack()
|
453 |
+
st.title('AutoPrompt Demo')
|
454 |
+
st.markdown('''
|
455 |
+
For many years, the predominant approach for training machine learning
|
456 |
+
models to solve NLP tasks has been to use supervised training data to
|
457 |
+
estimate model parameters using maximum likelihood estimation or some
|
458 |
+
similar paradigm. Whether fitting a logistic regression model over a
|
459 |
+
bag-of-words, an LSTM over a sequence of GloVe embeddings, or finetuning a
|
460 |
+
language model such as ELMo or BERT, the approach is essentially the same.
|
461 |
+
However, as language models have become more and more capable of accurately
|
462 |
+
generating plausible text a new possibility for solving classification
|
463 |
+
tasks has emerged...
|
464 |
+
|
465 |
+
## Prompting
|
466 |
+
|
467 |
+
Prompting is the method of converting classification tasks into
|
468 |
+
*fill-in-the-blanks* problems that can be solved by a language model **without
|
469 |
+
modifying the model's internals**. For example, to perform sentiment analysis,
|
470 |
+
we may take the sentence we wish to classify and append the text "Overall, this
|
471 |
+
movie was ____." and feed it into a language model like so:
|
472 |
+
''')
|
473 |
+
# st.image('assets/bert-mouth.png', use_column_width=True)
|
474 |
+
st.markdown('''
|
475 |
+
By measuring whether the language model assigns a higher probability to
|
476 |
+
words that are associated with a **positive** sentiment ("good", "great",
|
477 |
+
and "fantastic") vs. words that are associated with a **negative**
|
478 |
+
sentiment ("bad", "terrible", or "awful") we can infer the
|
479 |
+
predicted label for the given input. So in this example, because the word "good"
|
480 |
+
has a higher probability than "bad", the predicted label is **positive**.
|
481 |
+
|
482 |
+
## AutoPrompt
|
483 |
+
|
484 |
+
One issue that arises when using prompts is that it is not usually clear
|
485 |
+
how to best pose a task as a fill-in-the-blanks problem in a way that gets
|
486 |
+
the most performance from the language model. Even for a simple problem
|
487 |
+
like sentiment analysis, we don't know whether it is better to ask whether
|
488 |
+
a movie is good/bad, or whether you feel great/terrible about it, and for
|
489 |
+
more abstract problems like natural language inference it is difficult to
|
490 |
+
even know where to start.
|
491 |
+
|
492 |
+
To cure this writer's block we introduce **AutoPrompt**, a data-driven
|
493 |
+
approach for automatic prompt construction. The basic idea is
|
494 |
+
straightfoward: instead of writing a prompt, a user need only write a
|
495 |
+
**template** that specfies where the *task inputs* go along with placeholders for
|
496 |
+
a number of *trigger tokens* that will automatically be learned by the
|
497 |
+
model and the *predict token* that the model will fill in:
|
498 |
+
''')
|
499 |
+
# st.image('assets/template.png', use_column_width=True)
|
500 |
+
st.markdown(
|
501 |
+
'''
|
502 |
+
In each iteration of the search process:
|
503 |
+
1. The template is instantiated using a batch of training inputs.
|
504 |
+
2. The loss of the model on each input is measured and used to identify a
|
505 |
+
number of candidate replacements for the current trigger tokens.
|
506 |
+
3. The performance of each candidate is measured on another batch of
|
507 |
+
training data, and the best performing candidate is used in the next
|
508 |
+
iteration.
|
509 |
+
|
510 |
+
### Demo
|
511 |
+
|
512 |
+
To give a better sense of how AutoPrompt works, we have provided a simple
|
513 |
+
interactive demo. You can generate a prompt using the training data we have
|
514 |
+
pre-populated for you, or alternatively write your own training/evaluation
|
515 |
+
instances or upload them using a csv below. In addition, you can vary
|
516 |
+
some of the training parameters, as well as the template using the sidebar
|
517 |
+
on the left.
|
518 |
+
'''
|
519 |
+
)
|
520 |
+
args = Args.from_streamlit()
|
521 |
+
dataset_mode = st.radio('How would you like to input your training data?',
|
522 |
+
options=['Example Data', 'Manual Input', 'From CSV'])
|
523 |
+
|
524 |
+
if dataset_mode == 'Example Data':
|
525 |
+
dataset = manual_dataset(use_defaults=True)
|
526 |
+
elif dataset_mode == 'Manual Input':
|
527 |
+
dataset = manual_dataset(use_defaults=False)
|
528 |
+
else:
|
529 |
+
dataset = csv_dataset()
|
530 |
+
|
531 |
+
button = st.empty()
|
532 |
+
clicked = button.button('Train')
|
533 |
+
|
534 |
+
if clicked:
|
535 |
+
trigger_tokens, eval_metric, label_map, templatizer, best_trigger_ids, tokenizer, predictor, args, train_output = run_autoprompt(args, dataset, cache_test=CacheTest(False))
|
536 |
+
else:
|
537 |
+
try:
|
538 |
+
trigger_tokens, eval_metric, label_map, templatizer, best_trigger_ids, tokenizer, predictor, args, train_output = run_autoprompt(args, dataset, cache_test=CacheTest(True))
|
539 |
+
except CacheMiss:
|
540 |
+
st.stop()
|
541 |
+
else:
|
542 |
+
button.empty()
|
543 |
+
|
544 |
+
|
545 |
+
st.markdown(f'**Final trigger**: {", ".join(trigger_tokens)}')
|
546 |
+
st.dataframe(pd.DataFrame(train_output).style.highlight_min(axis=1, color='#94666b'))
|
547 |
+
logger.debug('Dev metric')
|
548 |
+
st.write('Accuracy: ' + str(round(eval_metric.item()*100, 1)))
|
549 |
+
st.write("""
|
550 |
+
Et voila, you've now effectively finetuned a classifier using just a few
|
551 |
+
kilobytes of parameters (the tokens in the prompt). If you like you can
|
552 |
+
write down your "model" on the back of a napkin and take it with you.
|
553 |
+
|
554 |
+
### Try it out yourself!
|
555 |
+
|
556 |
+
""")
|
557 |
+
sentence = st.text_input("Sentence", 'Enter a test input here')
|
558 |
+
pred_output = predict_test([sentence], label_map ,templatizer, best_trigger_ids, tokenizer, predictor, args)
|
559 |
+
st.dataframe(pd.DataFrame(pred_output).style.highlight_min(axis=1, color='#94666b'))
|
560 |
+
|
561 |
+
st.markdown('''
|
562 |
+
## Where can I learn more?
|
563 |
+
|
564 |
+
If you are interested in learning more about AutoPrompt we recommend
|
565 |
+
[reading our paper](https://arxiv.org/abs/2010.15980) and [checking out our
|
566 |
+
code](https://github.com/ucinlp/autoprompt), or if you'd like you can also
|
567 |
+
watch our presentation at EMNLP 2020:
|
568 |
+
''')
|
569 |
+
st.components.v1.iframe(
|
570 |
+
src="https://www.youtube.com/embed/IBMT_oOCBbc",
|
571 |
+
height=400,
|
572 |
+
)
|
573 |
+
st.markdown('Thanks!')
|
574 |
+
|
575 |
+
|
576 |
+
if __name__ == '__main__':
|
577 |
+
logging.basicConfig(level=logging.INFO,
|
578 |
+
stream=sys.stdout)
|
579 |
+
run()
|
580 |
+
|
app/.streamlit/config.toml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[server]
|
2 |
+
enableCORS = false
|
3 |
+
enableXsrfProtection = false
|
4 |
+
|
5 |
+
[theme]
|
6 |
+
primaryColor="#96666b"
|
7 |
+
backgroundColor="#28282d"
|
8 |
+
secondaryBackgroundColor="#333333"
|
9 |
+
textColor="#f3f3f3"
|
10 |
+
font="monospace"
|
assets/icon.png
ADDED
![]() |
assets/sst2_train.jsonl
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"label": "terrible", "idx": "61123", "sentence": "in its yearning for the days "}
|
2 |
+
{"label": "great", "idx": "23159", "sentence": "of riveting set pieces "}
|
3 |
+
{"label": "great", "idx": "21277", "sentence": "all-star reunions "}
|
4 |
+
{"label": "terrible", "idx": "27987", "sentence": "( swimfan ) falls victim to sloppy plotting , an insultingly unbelievable final act and a villainess who is too crazy to be interesting . "}
|
5 |
+
{"label": "great", "idx": "23240", "sentence": "the leads ) are such a companionable couple "}
|
6 |
+
{"label": "great", "idx": "25838", "sentence": "astonishingly skillful and moving "}
|
7 |
+
{"label": "terrible", "idx": "35653", "sentence": "has been sacrificed for the sake of spectacle "}
|
8 |
+
{"label": "terrible", "idx": "49778", "sentence": "hard to imagine acting that could be any flatter "}
|
9 |
+
{"label": "great", "idx": "2403", "sentence": "the better video-game-based flicks , "}
|
10 |
+
{"label": "great", "idx": "135", "sentence": "so many of the challenges it poses for itself that one can forgive the film its flaws "}
|
11 |
+
{"label": "great", "idx": "42426", "sentence": "while somewhat less than it might have been , the film is a good one "}
|
12 |
+
{"label": "great", "idx": "6863", "sentence": "has a dashing and resourceful hero ; a lisping , reptilian villain ; big fights ; big hair ; lavish period scenery ; and a story "}
|
13 |
+
{"label": "terrible", "idx": "42330", "sentence": "in the end , the weight of water comes to resemble the kind of soft-core twaddle you 'd expect to see on showtime 's ` red shoe diaries . ' "}
|
14 |
+
{"label": "terrible", "idx": "57545", "sentence": "stuck in heaven because he 's afraid of his best-known creation ? "}
|
15 |
+
{"label": "terrible", "idx": "23530", "sentence": "can be as tiresome as 9 seconds of jesse helms ' anti- castro "}
|
16 |
+
{"label": "great", "idx": "54745", "sentence": "tackles the difficult subject of grief and loss with such life-embracing spirit that the theme does n't drag an audience down "}
|
17 |
+
{"label": "terrible", "idx": "30797", "sentence": "violence "}
|
18 |
+
{"label": "great", "idx": "30169", "sentence": "feminine energy , a tribute to the power of women to heal "}
|
19 |
+
{"label": "terrible", "idx": "42869", "sentence": "flat as a spoof "}
|
20 |
+
{"label": "terrible", "idx": "33313", "sentence": "( somebody suggested the stills might make a nice coffee table book ) "}
|
21 |
+
{"label": "great", "idx": "10766", "sentence": "brosnan 's finest non-bondish performance "}
|
22 |
+
{"label": "terrible", "idx": "15631", "sentence": "all seemed wasted like deniro 's once promising career and the once grand long beach boardwalk . "}
|
23 |
+
{"label": "great", "idx": "3401", "sentence": "a lot smarter and more unnerving than the sequels "}
|
24 |
+
{"label": "terrible", "idx": "29775", "sentence": "it may not be particularly innovative "}
|
25 |
+
{"label": "great", "idx": "39776", "sentence": "it 's packed with adventure and a worthwhile environmental message , so it 's great for the kids . "}
|
26 |
+
{"label": "terrible", "idx": "64246", "sentence": "mind ugly "}
|
27 |
+
{"label": "terrible", "idx": "57404", "sentence": "meandering , norton has to recite bland police procedural details , fiennes wanders around in an attempt to seem weird and distanced , hopkins looks like a drag queen "}
|
28 |
+
{"label": "terrible", "idx": "3766", "sentence": "pathetic idea "}
|
29 |
+
{"label": "great", "idx": "63925", "sentence": "has done his homework and "}
|
30 |
+
{"label": "great", "idx": "50687", "sentence": "a very compelling , sensitive , intelligent and almost cohesive piece "}
|
31 |
+
{"label": "terrible", "idx": "54945", "sentence": "turn and devolves "}
|
32 |
+
{"label": "great", "idx": "35874", "sentence": "about this silly , outrageous , ingenious thriller "}
|
autoprompt/__init__.py
ADDED
File without changes
|
autoprompt/create_trigger.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
from pathlib import Path
|
6 |
+
import random
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
import transformers
|
13 |
+
from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
import autoprompt.utils as utils
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class GradientStorage:
|
23 |
+
"""
|
24 |
+
This object stores the intermediate gradients of the output a the given PyTorch module, which
|
25 |
+
otherwise might not be retained.
|
26 |
+
"""
|
27 |
+
def __init__(self, module):
|
28 |
+
self._stored_gradient = None
|
29 |
+
module.register_backward_hook(self.hook)
|
30 |
+
|
31 |
+
def hook(self, module, grad_in, grad_out):
|
32 |
+
self._stored_gradient = grad_out[0]
|
33 |
+
|
34 |
+
def get(self):
|
35 |
+
return self._stored_gradient
|
36 |
+
|
37 |
+
|
38 |
+
class PredictWrapper:
|
39 |
+
"""
|
40 |
+
PyTorch transformers model wrapper. Handles necc. preprocessing of inputs for triggers
|
41 |
+
experiments.
|
42 |
+
"""
|
43 |
+
def __init__(self, model):
|
44 |
+
self._model = model
|
45 |
+
|
46 |
+
def __call__(self, model_inputs, trigger_ids):
|
47 |
+
# Copy dict so pop operations don't have unwanted side-effects
|
48 |
+
model_inputs = model_inputs.copy()
|
49 |
+
trigger_mask = model_inputs.pop('trigger_mask')
|
50 |
+
predict_mask = model_inputs.pop('predict_mask')
|
51 |
+
model_inputs = replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
|
52 |
+
logits, *_ = self._model(**model_inputs)
|
53 |
+
predict_logits = logits.masked_select(predict_mask.unsqueeze(-1)).view(logits.size(0), -1)
|
54 |
+
return predict_logits
|
55 |
+
|
56 |
+
|
57 |
+
class AccuracyFn:
|
58 |
+
"""
|
59 |
+
Computing the accuracy when a label is mapped to multiple tokens is difficult in the current
|
60 |
+
framework, since the data generator only gives us the token ids. To get around this we
|
61 |
+
compare the target logp to the logp of all labels. If target logp is greater than all (but)
|
62 |
+
one of the label logps we know we are accurate.
|
63 |
+
"""
|
64 |
+
def __init__(self, tokenizer, label_map, device, tokenize_labels=False):
|
65 |
+
self._all_label_ids = []
|
66 |
+
self._pred_to_label = []
|
67 |
+
logger.info(label_map)
|
68 |
+
for label, label_tokens in label_map.items():
|
69 |
+
self._all_label_ids.append(utils.encode_label(tokenizer, label_tokens, tokenize_labels).to(device))
|
70 |
+
self._pred_to_label.append(label)
|
71 |
+
logger.info(self._all_label_ids)
|
72 |
+
|
73 |
+
def __call__(self, predict_logits, gold_label_ids):
|
74 |
+
# Get total log-probability for the true label
|
75 |
+
gold_logp = get_loss(predict_logits, gold_label_ids)
|
76 |
+
|
77 |
+
# Get total log-probability for all labels
|
78 |
+
bsz = predict_logits.size(0)
|
79 |
+
all_label_logp = []
|
80 |
+
for label_ids in self._all_label_ids:
|
81 |
+
label_logp = get_loss(predict_logits, label_ids.repeat(bsz, 1))
|
82 |
+
all_label_logp.append(label_logp)
|
83 |
+
all_label_logp = torch.stack(all_label_logp, dim=-1)
|
84 |
+
_, predictions = all_label_logp.max(dim=-1)
|
85 |
+
predictions = [self._pred_to_label[x] for x in predictions.tolist()]
|
86 |
+
|
87 |
+
# Add up the number of entries where loss is greater than or equal to gold_logp.
|
88 |
+
ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1)
|
89 |
+
correct = ge_count.le(1) # less than in case of num. prec. issues
|
90 |
+
|
91 |
+
return correct.float()
|
92 |
+
|
93 |
+
# TODO: @rloganiv - This is hacky. Replace with something sensible.
|
94 |
+
def predict(self, predict_logits):
|
95 |
+
bsz = predict_logits.size(0)
|
96 |
+
all_label_logp = []
|
97 |
+
for label_ids in self._all_label_ids:
|
98 |
+
label_logp = get_loss(predict_logits, label_ids.repeat(bsz, 1))
|
99 |
+
all_label_logp.append(label_logp)
|
100 |
+
all_label_logp = torch.stack(all_label_logp, dim=-1)
|
101 |
+
_, predictions = all_label_logp.max(dim=-1)
|
102 |
+
predictions = [self._pred_to_label[x] for x in predictions.tolist()]
|
103 |
+
return predictions
|
104 |
+
|
105 |
+
|
106 |
+
def load_pretrained(model_name):
|
107 |
+
"""
|
108 |
+
Loads pretrained HuggingFace config/model/tokenizer, as well as performs required
|
109 |
+
initialization steps to facilitate working with triggers.
|
110 |
+
"""
|
111 |
+
config = AutoConfig.from_pretrained(model_name)
|
112 |
+
model = AutoModelWithLMHead.from_pretrained(model_name)
|
113 |
+
model.eval()
|
114 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
|
115 |
+
utils.add_task_specific_tokens(tokenizer)
|
116 |
+
return config, model, tokenizer
|
117 |
+
|
118 |
+
|
119 |
+
def set_seed(seed: int):
|
120 |
+
"""Sets the relevant random seeds."""
|
121 |
+
random.seed(seed)
|
122 |
+
np.random.seed(seed)
|
123 |
+
torch.random.manual_seed(seed)
|
124 |
+
torch.cuda.manual_seed(seed)
|
125 |
+
|
126 |
+
|
127 |
+
def get_embeddings(model, config):
|
128 |
+
"""Returns the wordpiece embedding module."""
|
129 |
+
base_model = getattr(model, config.model_type)
|
130 |
+
embeddings = base_model.embeddings.word_embeddings
|
131 |
+
return embeddings
|
132 |
+
|
133 |
+
|
134 |
+
def hotflip_attack(averaged_grad,
|
135 |
+
embedding_matrix,
|
136 |
+
increase_loss=False,
|
137 |
+
num_candidates=1,
|
138 |
+
filter=None):
|
139 |
+
"""Returns the top candidate replacements."""
|
140 |
+
with torch.no_grad():
|
141 |
+
gradient_dot_embedding_matrix = torch.matmul(
|
142 |
+
embedding_matrix,
|
143 |
+
averaged_grad
|
144 |
+
)
|
145 |
+
if filter is not None:
|
146 |
+
gradient_dot_embedding_matrix -= filter
|
147 |
+
if not increase_loss:
|
148 |
+
gradient_dot_embedding_matrix *= -1
|
149 |
+
_, top_k_ids = gradient_dot_embedding_matrix.topk(num_candidates)
|
150 |
+
|
151 |
+
return top_k_ids
|
152 |
+
|
153 |
+
|
154 |
+
def replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask):
|
155 |
+
"""Replaces the trigger tokens in input_ids."""
|
156 |
+
out = model_inputs.copy()
|
157 |
+
input_ids = model_inputs['input_ids']
|
158 |
+
trigger_ids = trigger_ids.repeat(trigger_mask.size(0), 1)
|
159 |
+
try:
|
160 |
+
filled = input_ids.masked_scatter(trigger_mask, trigger_ids)
|
161 |
+
except RuntimeError:
|
162 |
+
filled = input_ids
|
163 |
+
out['input_ids'] = filled
|
164 |
+
return out
|
165 |
+
|
166 |
+
|
167 |
+
def get_loss(predict_logits, label_ids):
|
168 |
+
predict_logp = F.log_softmax(predict_logits, dim=-1)
|
169 |
+
target_logp = predict_logp.gather(-1, label_ids)
|
170 |
+
target_logp = target_logp - 1e32 * label_ids.eq(0) # Apply mask
|
171 |
+
target_logp = torch.logsumexp(target_logp, dim=-1)
|
172 |
+
return -target_logp
|
173 |
+
|
174 |
+
|
175 |
+
def isupper(idx, tokenizer):
|
176 |
+
"""
|
177 |
+
Determines whether a token (e.g., word piece) begins with a capital letter.
|
178 |
+
"""
|
179 |
+
_isupper = False
|
180 |
+
# We only want to check tokens that begin words. Since byte-pair encoding
|
181 |
+
# captures a prefix space, we need to check that the decoded token begins
|
182 |
+
# with a space, and has a capitalized second character.
|
183 |
+
if isinstance(tokenizer, transformers.GPT2Tokenizer):
|
184 |
+
decoded = tokenizer.decode([idx])
|
185 |
+
if decoded[0] == ' ' and decoded[1].isupper():
|
186 |
+
_isupper = True
|
187 |
+
# For all other tokenization schemes, we can just check the first character
|
188 |
+
# is capitalized.
|
189 |
+
elif tokenizer.decode([idx])[0].isupper():
|
190 |
+
_isupper = True
|
191 |
+
return _isupper
|
192 |
+
|
193 |
+
|
194 |
+
def run_model(args):
|
195 |
+
|
196 |
+
set_seed(args.seed)
|
197 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
198 |
+
|
199 |
+
logger.info('Loading model, tokenizer, etc.')
|
200 |
+
config, model, tokenizer = load_pretrained(args.model_name)
|
201 |
+
model.to(device)
|
202 |
+
embeddings = get_embeddings(model, config)
|
203 |
+
embedding_gradient = GradientStorage(embeddings)
|
204 |
+
predictor = PredictWrapper(model)
|
205 |
+
|
206 |
+
if args.label_map is not None:
|
207 |
+
label_map = json.loads(args.label_map)
|
208 |
+
logger.info(f"Label map: {label_map}")
|
209 |
+
else:
|
210 |
+
label_map = None
|
211 |
+
logger.info('No label map')
|
212 |
+
|
213 |
+
templatizer = utils.TriggerTemplatizer(
|
214 |
+
args.template,
|
215 |
+
config,
|
216 |
+
tokenizer,
|
217 |
+
label_map=label_map,
|
218 |
+
label_field=args.label_field,
|
219 |
+
tokenize_labels=args.tokenize_labels,
|
220 |
+
add_special_tokens=False,
|
221 |
+
use_ctx=args.use_ctx
|
222 |
+
)
|
223 |
+
|
224 |
+
# Obtain the initial trigger tokens and label mapping
|
225 |
+
if args.initial_trigger:
|
226 |
+
trigger_ids = tokenizer.convert_tokens_to_ids(args.initial_trigger)
|
227 |
+
logger.debug(f'Initial trigger: {args.initial_trigger}')
|
228 |
+
logger.debug(f'Trigger ids: {trigger_ids}')
|
229 |
+
assert len(trigger_ids) == templatizer.num_trigger_tokens
|
230 |
+
else:
|
231 |
+
trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
|
232 |
+
trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)
|
233 |
+
best_trigger_ids = trigger_ids.clone()
|
234 |
+
|
235 |
+
# NOTE: Accuracy can only be computed if a fixed pool of labels is given, which currently
|
236 |
+
# requires the label map to be specified. Since producing a label map may be cumbersome (e.g.,
|
237 |
+
# for link prediction tasks), we just use (negative) loss as the evaluation metric in these cases.
|
238 |
+
if label_map:
|
239 |
+
evaluation_fn = AccuracyFn(tokenizer, label_map, device)
|
240 |
+
else:
|
241 |
+
evaluation_fn = lambda x, y: -get_loss(x, y)
|
242 |
+
|
243 |
+
logger.info('Loading datasets')
|
244 |
+
collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
|
245 |
+
|
246 |
+
if args.perturbed:
|
247 |
+
train_dataset = utils.load_augmented_trigger_dataset(args.train, templatizer, limit=args.limit)
|
248 |
+
else:
|
249 |
+
train_dataset = utils.load_trigger_dataset(args.train, templatizer, use_ctx=args.use_ctx, limit=args.limit)
|
250 |
+
train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
|
251 |
+
|
252 |
+
if args.perturbed:
|
253 |
+
dev_dataset = utils.load_augmented_trigger_dataset(args.dev, templatizer)
|
254 |
+
else:
|
255 |
+
dev_dataset = utils.load_trigger_dataset(args.dev, templatizer, use_ctx=args.use_ctx)
|
256 |
+
dev_loader = DataLoader(dev_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)
|
257 |
+
|
258 |
+
# To "filter" unwanted trigger tokens, we subtract a huge number from their logits.
|
259 |
+
filter = torch.zeros(tokenizer.vocab_size, dtype=torch.float32, device=device)
|
260 |
+
if args.filter:
|
261 |
+
logger.info('Filtering label tokens.')
|
262 |
+
if label_map:
|
263 |
+
for label_tokens in label_map.values():
|
264 |
+
label_ids = utils.encode_label(tokenizer, label_tokens).unsqueeze(0)
|
265 |
+
filter[label_ids] = -1e32
|
266 |
+
else:
|
267 |
+
for _, label_ids in train_dataset:
|
268 |
+
filter[label_ids] = -1e32
|
269 |
+
logger.info('Filtering special tokens and capitalized words.')
|
270 |
+
for word, idx in tokenizer.get_vocab().items():
|
271 |
+
if len(word) == 1 or idx >= tokenizer.vocab_size:
|
272 |
+
continue
|
273 |
+
# Filter special tokens.
|
274 |
+
if idx in tokenizer.all_special_ids:
|
275 |
+
logger.debug('Filtered: %s', word)
|
276 |
+
filter[idx] = -1e32
|
277 |
+
# Filter capitalized words (lazy way to remove proper nouns).
|
278 |
+
if isupper(idx, tokenizer):
|
279 |
+
logger.debug('Filtered: %s', word)
|
280 |
+
filter[idx] = -1e32
|
281 |
+
|
282 |
+
logger.info('Evaluating')
|
283 |
+
numerator = 0
|
284 |
+
denominator = 0
|
285 |
+
for model_inputs, labels in tqdm(dev_loader):
|
286 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
287 |
+
labels = labels.to(device)
|
288 |
+
with torch.no_grad():
|
289 |
+
predict_logits = predictor(model_inputs, trigger_ids)
|
290 |
+
numerator += evaluation_fn(predict_logits, labels).sum().item()
|
291 |
+
denominator += labels.size(0)
|
292 |
+
dev_metric = numerator / (denominator + 1e-13)
|
293 |
+
logger.info(f'Dev metric: {dev_metric}')
|
294 |
+
|
295 |
+
best_dev_metric = -float('inf')
|
296 |
+
# Measure elapsed time of trigger search
|
297 |
+
start = time.time()
|
298 |
+
|
299 |
+
for i in range(args.iters):
|
300 |
+
|
301 |
+
logger.info(f'Iteration: {i}')
|
302 |
+
|
303 |
+
logger.info('Accumulating Gradient')
|
304 |
+
model.zero_grad()
|
305 |
+
|
306 |
+
pbar = tqdm(range(args.accumulation_steps))
|
307 |
+
train_iter = iter(train_loader)
|
308 |
+
averaged_grad = None
|
309 |
+
|
310 |
+
# Accumulate
|
311 |
+
for step in pbar:
|
312 |
+
|
313 |
+
# Shuttle inputs to GPU
|
314 |
+
try:
|
315 |
+
model_inputs, labels = next(train_iter)
|
316 |
+
except:
|
317 |
+
logger.warning(
|
318 |
+
'Insufficient data for number of accumulation steps. '
|
319 |
+
'Effective batch size will be smaller than specified.'
|
320 |
+
)
|
321 |
+
break
|
322 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
323 |
+
labels = labels.to(device)
|
324 |
+
predict_logits = predictor(model_inputs, trigger_ids)
|
325 |
+
loss = get_loss(predict_logits, labels).mean()
|
326 |
+
loss.backward()
|
327 |
+
|
328 |
+
grad = embedding_gradient.get()
|
329 |
+
bsz, _, emb_dim = grad.size()
|
330 |
+
selection_mask = model_inputs['trigger_mask'].unsqueeze(-1)
|
331 |
+
grad = torch.masked_select(grad, selection_mask)
|
332 |
+
grad = grad.view(bsz, templatizer.num_trigger_tokens, emb_dim)
|
333 |
+
|
334 |
+
if averaged_grad is None:
|
335 |
+
averaged_grad = grad.sum(dim=0) / args.accumulation_steps
|
336 |
+
else:
|
337 |
+
averaged_grad += grad.sum(dim=0) / args.accumulation_steps
|
338 |
+
|
339 |
+
logger.info('Evaluating Candidates')
|
340 |
+
pbar = tqdm(range(args.accumulation_steps))
|
341 |
+
train_iter = iter(train_loader)
|
342 |
+
|
343 |
+
token_to_flip = random.randrange(templatizer.num_trigger_tokens)
|
344 |
+
candidates = hotflip_attack(averaged_grad[token_to_flip],
|
345 |
+
embeddings.weight,
|
346 |
+
increase_loss=False,
|
347 |
+
num_candidates=args.num_cand,
|
348 |
+
filter=filter)
|
349 |
+
|
350 |
+
current_score = 0
|
351 |
+
candidate_scores = torch.zeros(args.num_cand, device=device)
|
352 |
+
denom = 0
|
353 |
+
for step in pbar:
|
354 |
+
|
355 |
+
try:
|
356 |
+
model_inputs, labels = next(train_iter)
|
357 |
+
except:
|
358 |
+
logger.warning(
|
359 |
+
'Insufficient data for number of accumulation steps. '
|
360 |
+
'Effective batch size will be smaller than specified.'
|
361 |
+
)
|
362 |
+
break
|
363 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
364 |
+
labels = labels.to(device)
|
365 |
+
with torch.no_grad():
|
366 |
+
predict_logits = predictor(model_inputs, trigger_ids)
|
367 |
+
eval_metric = evaluation_fn(predict_logits, labels)
|
368 |
+
|
369 |
+
# Update current score
|
370 |
+
current_score += eval_metric.sum()
|
371 |
+
denom += labels.size(0)
|
372 |
+
|
373 |
+
# NOTE: Instead of iterating over tokens to flip we randomly change just one each
|
374 |
+
# time so the gradients don't get stale.
|
375 |
+
for i, candidate in enumerate(candidates):
|
376 |
+
|
377 |
+
# if candidate.item() in filter_candidates:
|
378 |
+
# candidate_scores[i] = -1e32
|
379 |
+
# continue
|
380 |
+
|
381 |
+
temp_trigger = trigger_ids.clone()
|
382 |
+
temp_trigger[:, token_to_flip] = candidate
|
383 |
+
with torch.no_grad():
|
384 |
+
predict_logits = predictor(model_inputs, temp_trigger)
|
385 |
+
eval_metric = evaluation_fn(predict_logits, labels)
|
386 |
+
|
387 |
+
candidate_scores[i] += eval_metric.sum()
|
388 |
+
|
389 |
+
# TODO: Something cleaner. LAMA templates can't have mask tokens, so if
|
390 |
+
# there are still mask tokens in the trigger then set the current score
|
391 |
+
# to -inf.
|
392 |
+
if args.print_lama:
|
393 |
+
if trigger_ids.eq(tokenizer.mask_token_id).any():
|
394 |
+
current_score = float('-inf')
|
395 |
+
|
396 |
+
if (candidate_scores > current_score).any():
|
397 |
+
logger.info('Better trigger detected.')
|
398 |
+
best_candidate_score = candidate_scores.max()
|
399 |
+
best_candidate_idx = candidate_scores.argmax()
|
400 |
+
trigger_ids[:, token_to_flip] = candidates[best_candidate_idx]
|
401 |
+
logger.info(f'Train metric: {best_candidate_score / (denom + 1e-13): 0.4f}')
|
402 |
+
else:
|
403 |
+
logger.info('No improvement detected. Skipping evaluation.')
|
404 |
+
continue
|
405 |
+
|
406 |
+
logger.info('Evaluating')
|
407 |
+
numerator = 0
|
408 |
+
denominator = 0
|
409 |
+
for model_inputs, labels in tqdm(dev_loader):
|
410 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
411 |
+
labels = labels.to(device)
|
412 |
+
with torch.no_grad():
|
413 |
+
predict_logits = predictor(model_inputs, trigger_ids)
|
414 |
+
numerator += evaluation_fn(predict_logits, labels).sum().item()
|
415 |
+
denominator += labels.size(0)
|
416 |
+
dev_metric = numerator / (denominator + 1e-13)
|
417 |
+
|
418 |
+
logger.info(f'Trigger tokens: {tokenizer.convert_ids_to_tokens(trigger_ids.squeeze(0))}')
|
419 |
+
logger.info(f'Dev metric: {dev_metric}')
|
420 |
+
|
421 |
+
# TODO: Something cleaner. LAMA templates can't have mask tokens, so if
|
422 |
+
# there are still mask tokens in the trigger then set the current score
|
423 |
+
# to -inf.
|
424 |
+
if args.print_lama:
|
425 |
+
if best_trigger_ids.eq(tokenizer.mask_token_id).any():
|
426 |
+
best_dev_metric = float('-inf')
|
427 |
+
|
428 |
+
if dev_metric > best_dev_metric:
|
429 |
+
logger.info('Best performance so far')
|
430 |
+
best_trigger_ids = trigger_ids.clone()
|
431 |
+
best_dev_metric = dev_metric
|
432 |
+
|
433 |
+
best_trigger_tokens = tokenizer.convert_ids_to_tokens(best_trigger_ids.squeeze(0))
|
434 |
+
logger.info(f'Best tokens: {best_trigger_tokens}')
|
435 |
+
logger.info(f'Best dev metric: {best_dev_metric}')
|
436 |
+
if args.print_lama:
|
437 |
+
# Templatize with [X] and [Y]
|
438 |
+
if args.use_ctx:
|
439 |
+
model_inputs, label_ids = templatizer({
|
440 |
+
'sub_label': '[X]',
|
441 |
+
'obj_label': tokenizer.lama_y,
|
442 |
+
'context': ''
|
443 |
+
})
|
444 |
+
else:
|
445 |
+
model_inputs, label_ids = templatizer({
|
446 |
+
'sub_label': '[X]',
|
447 |
+
'obj_label': tokenizer.lama_y,
|
448 |
+
})
|
449 |
+
lama_template = model_inputs['input_ids']
|
450 |
+
# Instantiate trigger tokens
|
451 |
+
lama_template.masked_scatter_(
|
452 |
+
mask=model_inputs['trigger_mask'],
|
453 |
+
source=best_trigger_ids.cpu())
|
454 |
+
# Instantiate label token
|
455 |
+
lama_template.masked_scatter_(
|
456 |
+
mask=model_inputs['predict_mask'],
|
457 |
+
source=label_ids)
|
458 |
+
# Print LAMA JSON template
|
459 |
+
relation = args.train.parent.stem
|
460 |
+
|
461 |
+
# The following block of code is a bit hacky but whatever, it gets the job done
|
462 |
+
if args.use_ctx:
|
463 |
+
template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[SEP] ', '').replace('</s> ', '').replace('[ X ]', '[X]')
|
464 |
+
else:
|
465 |
+
template = tokenizer.decode(lama_template.squeeze(0)[1:-1]).replace('[ X ]', '[X]')
|
466 |
+
|
467 |
+
out = {
|
468 |
+
'relation': args.train.parent.stem,
|
469 |
+
'template': template
|
470 |
+
}
|
471 |
+
print(json.dumps(out))
|
472 |
+
|
473 |
+
|
474 |
+
if __name__ == '__main__':
|
475 |
+
parser = argparse.ArgumentParser()
|
476 |
+
parser.add_argument('--train', type=Path, required=True, help='Train data path')
|
477 |
+
parser.add_argument('--dev', type=Path, required=True, help='Dev data path')
|
478 |
+
parser.add_argument('--template', type=str, help='Template string')
|
479 |
+
parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map')
|
480 |
+
|
481 |
+
# LAMA-specific
|
482 |
+
parser.add_argument('--tokenize-labels', action='store_true',
|
483 |
+
help='If specified labels are split into word pieces.'
|
484 |
+
'Needed for LAMA probe experiments.')
|
485 |
+
parser.add_argument('--filter', action='store_true',
|
486 |
+
help='If specified, filter out special tokens and gold objects.'
|
487 |
+
'Furthermore, tokens starting with capital '
|
488 |
+
'letters will not appear in triggers. Lazy '
|
489 |
+
'approach for removing proper nouns.')
|
490 |
+
parser.add_argument('--print-lama', action='store_true',
|
491 |
+
help='Prints best trigger in LAMA format.')
|
492 |
+
|
493 |
+
parser.add_argument('--initial-trigger', nargs='+', type=str, default=None, help='Manual prompt')
|
494 |
+
parser.add_argument('--label-field', type=str, default='label',
|
495 |
+
help='Name of the label field')
|
496 |
+
|
497 |
+
parser.add_argument('--bsz', type=int, default=32, help='Batch size')
|
498 |
+
parser.add_argument('--eval-size', type=int, default=256, help='Eval size')
|
499 |
+
parser.add_argument('--iters', type=int, default=100,
|
500 |
+
help='Number of iterations to run trigger search algorithm')
|
501 |
+
parser.add_argument('--accumulation-steps', type=int, default=10)
|
502 |
+
parser.add_argument('--model-name', type=str, default='bert-base-cased',
|
503 |
+
help='Model name passed to HuggingFace AutoX classes.')
|
504 |
+
parser.add_argument('--seed', type=int, default=0)
|
505 |
+
parser.add_argument('--limit', type=int, default=None)
|
506 |
+
parser.add_argument('--use-ctx', action='store_true',
|
507 |
+
help='Use context sentences for relation extraction only')
|
508 |
+
parser.add_argument('--perturbed', action='store_true',
|
509 |
+
help='Perturbed sentence evaluation of relation extraction: replace each object in dataset with a random other object')
|
510 |
+
parser.add_argument('--patience', type=int, default=5)
|
511 |
+
parser.add_argument('--num-cand', type=int, default=10)
|
512 |
+
parser.add_argument('--sentence-size', type=int, default=50)
|
513 |
+
|
514 |
+
parser.add_argument('--debug', action='store_true')
|
515 |
+
args = parser.parse_args()
|
516 |
+
|
517 |
+
if args.debug:
|
518 |
+
level = logging.DEBUG
|
519 |
+
else:
|
520 |
+
level = logging.INFO
|
521 |
+
logging.basicConfig(level=level)
|
522 |
+
|
523 |
+
run_model(args)
|
autoprompt/finetune.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Script for running finetuning on glue tasks.
|
3 |
+
|
4 |
+
Largely copied from:
|
5 |
+
https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py
|
6 |
+
"""
|
7 |
+
import argparse
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
import random
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
from torch.optim.lr_scheduler import LambdaLR
|
17 |
+
import transformers
|
18 |
+
from transformers import (
|
19 |
+
AdamW, AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
|
20 |
+
)
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
import autoprompt.utils as utils
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
def set_seed(seed: int):
|
30 |
+
"""Sets the relevant random seeds."""
|
31 |
+
random.seed(seed)
|
32 |
+
np.random.seed(seed)
|
33 |
+
torch.random.manual_seed(seed)
|
34 |
+
torch.cuda.manual_seed(seed)
|
35 |
+
|
36 |
+
|
37 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
38 |
+
""" Create a schedule with a learning rate that decreases linearly after
|
39 |
+
linearly increasing during a warmup period.
|
40 |
+
|
41 |
+
From:
|
42 |
+
https://github.com/uds-lsv/bert-stable-fine-tuning/blob/master/src/transformers/optimization.py
|
43 |
+
"""
|
44 |
+
|
45 |
+
def lr_lambda(current_step):
|
46 |
+
if current_step < num_warmup_steps:
|
47 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
48 |
+
return max(
|
49 |
+
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
50 |
+
)
|
51 |
+
|
52 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
53 |
+
|
54 |
+
|
55 |
+
def main(args):
|
56 |
+
set_seed(args.seed)
|
57 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
58 |
+
|
59 |
+
config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
|
60 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
61 |
+
model = AutoModelForSequenceClassification.from_pretrained(args.model_name, config=config)
|
62 |
+
model.to(device)
|
63 |
+
|
64 |
+
collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
|
65 |
+
train_dataset, label_map = utils.load_classification_dataset(
|
66 |
+
args.train,
|
67 |
+
tokenizer,
|
68 |
+
args.field_a,
|
69 |
+
args.field_b,
|
70 |
+
args.label_field,
|
71 |
+
limit=args.limit
|
72 |
+
)
|
73 |
+
train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
|
74 |
+
dev_dataset, _ = utils.load_classification_dataset(
|
75 |
+
args.dev,
|
76 |
+
tokenizer,
|
77 |
+
args.field_a,
|
78 |
+
args.field_b,
|
79 |
+
args.label_field,
|
80 |
+
label_map
|
81 |
+
)
|
82 |
+
dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
|
83 |
+
test_dataset, _ = utils.load_classification_dataset(
|
84 |
+
args.test,
|
85 |
+
tokenizer,
|
86 |
+
args.field_a,
|
87 |
+
args.field_b,
|
88 |
+
args.label_field,
|
89 |
+
label_map
|
90 |
+
)
|
91 |
+
test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
|
92 |
+
|
93 |
+
if args.bias_correction:
|
94 |
+
betas = (0.9, 0.999)
|
95 |
+
else:
|
96 |
+
betas = (0.0, 0.000)
|
97 |
+
|
98 |
+
optimizer = AdamW(
|
99 |
+
model.parameters(),
|
100 |
+
lr=args.lr,
|
101 |
+
weight_decay=1e-2,
|
102 |
+
betas=betas
|
103 |
+
)
|
104 |
+
|
105 |
+
# Use suggested learning rate scheduler
|
106 |
+
num_training_steps = len(train_dataset) * args.epochs // args.bsz
|
107 |
+
num_warmup_steps = num_training_steps // 10
|
108 |
+
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
|
109 |
+
num_training_steps)
|
110 |
+
|
111 |
+
if not args.ckpt_dir.exists():
|
112 |
+
logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
|
113 |
+
args.ckpt_dir.mkdir(parents=True)
|
114 |
+
elif not args.force_overwrite:
|
115 |
+
raise RuntimeError('Checkpoint directory already exists.')
|
116 |
+
|
117 |
+
try:
|
118 |
+
best_accuracy = 0
|
119 |
+
for epoch in range(args.epochs):
|
120 |
+
logger.info('Training...')
|
121 |
+
model.train()
|
122 |
+
avg_loss = utils.ExponentialMovingAverage()
|
123 |
+
pbar = tqdm(train_loader)
|
124 |
+
for model_inputs, labels in pbar:
|
125 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
126 |
+
labels = labels.to(device)
|
127 |
+
optimizer.zero_grad()
|
128 |
+
logits, *_ = model(**model_inputs)
|
129 |
+
loss = F.cross_entropy(logits, labels.squeeze(-1))
|
130 |
+
loss.backward()
|
131 |
+
optimizer.step()
|
132 |
+
scheduler.step()
|
133 |
+
avg_loss.update(loss.item())
|
134 |
+
pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}, '
|
135 |
+
f'lr: {optimizer.param_groups[0]["lr"]: .3e}')
|
136 |
+
|
137 |
+
logger.info('Evaluating...')
|
138 |
+
model.eval()
|
139 |
+
correct = 0
|
140 |
+
total = 0
|
141 |
+
with torch.no_grad():
|
142 |
+
for model_inputs, labels in dev_loader:
|
143 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
144 |
+
labels = labels.to(device)
|
145 |
+
logits, *_ = model(**model_inputs)
|
146 |
+
_, preds = logits.max(dim=-1)
|
147 |
+
correct += (preds == labels.squeeze(-1)).sum().item()
|
148 |
+
total += labels.size(0)
|
149 |
+
accuracy = correct / (total + 1e-13)
|
150 |
+
logger.info(f'Accuracy: {accuracy : 0.4f}')
|
151 |
+
|
152 |
+
if accuracy > best_accuracy:
|
153 |
+
logger.info('Best performance so far.')
|
154 |
+
model.save_pretrained(args.ckpt_dir)
|
155 |
+
tokenizer.save_pretrained(args.ckpt_dir)
|
156 |
+
best_accuracy = accuracy
|
157 |
+
except KeyboardInterrupt:
|
158 |
+
logger.info('Interrupted...')
|
159 |
+
|
160 |
+
logger.info('Testing...')
|
161 |
+
model.eval()
|
162 |
+
correct = 0
|
163 |
+
total = 0
|
164 |
+
with torch.no_grad():
|
165 |
+
for model_inputs, labels in test_loader:
|
166 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
167 |
+
labels = labels.to(device)
|
168 |
+
logits, *_ = model(**model_inputs)
|
169 |
+
_, preds = logits.max(dim=-1)
|
170 |
+
correct += (preds == labels.squeeze(-1)).sum().item()
|
171 |
+
total += labels.size(0)
|
172 |
+
accuracy = correct / (total + 1e-13)
|
173 |
+
logger.info(f'Accuracy: {accuracy : 0.4f}')
|
174 |
+
|
175 |
+
|
176 |
+
if __name__ == '__main__':
|
177 |
+
parser = argparse.ArgumentParser()
|
178 |
+
parser.add_argument('--model-name', type=str)
|
179 |
+
parser.add_argument('--train', type=Path)
|
180 |
+
parser.add_argument('--dev', type=Path)
|
181 |
+
parser.add_argument('--test', type=Path)
|
182 |
+
parser.add_argument('--field-a', type=str)
|
183 |
+
parser.add_argument('--field-b', type=str, default=None)
|
184 |
+
parser.add_argument('--label-field', type=str, default='label')
|
185 |
+
parser.add_argument('--ckpt-dir', type=Path, default=Path('ckpt/'))
|
186 |
+
parser.add_argument('--num-labels', type=int, default=2)
|
187 |
+
parser.add_argument('--bsz', type=int, default=32)
|
188 |
+
parser.add_argument('--epochs', type=int, default=3)
|
189 |
+
parser.add_argument('--lr', type=float, default=2e-5)
|
190 |
+
parser.add_argument('--limit', type=int, default=None)
|
191 |
+
parser.add_argument('--seed', type=int, default=1234)
|
192 |
+
parser.add_argument('--bias-correction', action='store_true')
|
193 |
+
parser.add_argument('-f', '--force-overwrite', action='store_true')
|
194 |
+
parser.add_argument('--debug', action='store_true')
|
195 |
+
args = parser.parse_args()
|
196 |
+
|
197 |
+
if args.debug:
|
198 |
+
level = logging.DEBUG
|
199 |
+
else:
|
200 |
+
level = logging.INFO
|
201 |
+
logging.basicConfig(level=level)
|
202 |
+
|
203 |
+
main(args)
|
autoprompt/label_search.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This is a hacky little attempt using the tools from the trigger creation script to identify a
|
3 |
+
good set of label strings. The idea is to train a linear classifier over the predict token and
|
4 |
+
then look at the most similar tokens.
|
5 |
+
"""
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from transformers import (
|
15 |
+
AutoConfig, AutoModelWithLMHead, AutoTokenizer, BertForMaskedLM, RobertaForMaskedLM
|
16 |
+
)
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
import autoprompt.utils as utils
|
20 |
+
import autoprompt.create_trigger as ct
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
def load_pretrained(model_name):
|
27 |
+
"""
|
28 |
+
Loads pretrained HuggingFace config/model/tokenizer, as well as performs required
|
29 |
+
initialization steps to facilitate working with triggers.
|
30 |
+
"""
|
31 |
+
config = AutoConfig.from_pretrained(args.model_name)
|
32 |
+
model = AutoModelWithLMHead.from_pretrained(args.model_name, config=config)
|
33 |
+
model.eval()
|
34 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
35 |
+
utils.add_task_specific_tokens(tokenizer)
|
36 |
+
return config, model, tokenizer
|
37 |
+
|
38 |
+
|
39 |
+
def get_final_embeddings(model):
|
40 |
+
if isinstance(model, BertForMaskedLM):
|
41 |
+
return model.cls.predictions.transform
|
42 |
+
elif isinstance(model, RobertaForMaskedLM):
|
43 |
+
return model.lm_head.layer_norm
|
44 |
+
else:
|
45 |
+
raise NotImplementedError(f'{model} not currently supported')
|
46 |
+
|
47 |
+
|
48 |
+
def get_word_embeddings(model):
|
49 |
+
if isinstance(model, BertForMaskedLM):
|
50 |
+
return model.cls.predictions.decoder.weight
|
51 |
+
elif isinstance(model, RobertaForMaskedLM):
|
52 |
+
return model.lm_head.decoder.weight
|
53 |
+
else:
|
54 |
+
raise NotImplementedError(f'{model} not currently supported')
|
55 |
+
|
56 |
+
|
57 |
+
def main(args):
|
58 |
+
ct.set_seed(args.seed)
|
59 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
60 |
+
|
61 |
+
logger.info('Loading model, tokenizer, etc.')
|
62 |
+
config, model, tokenizer = load_pretrained(args.model_name)
|
63 |
+
model.to(device)
|
64 |
+
final_embeddings = get_final_embeddings(model)
|
65 |
+
embedding_storage = utils.OutputStorage(final_embeddings)
|
66 |
+
word_embeddings = get_word_embeddings(model)
|
67 |
+
|
68 |
+
label_map = json.loads(args.label_map)
|
69 |
+
reverse_label_map = {y: x for x, y in label_map.items()}
|
70 |
+
templatizer = utils.TriggerTemplatizer(
|
71 |
+
args.template,
|
72 |
+
tokenizer,
|
73 |
+
label_map=label_map,
|
74 |
+
label_field=args.label_field,
|
75 |
+
add_special_tokens=False
|
76 |
+
)
|
77 |
+
|
78 |
+
# The weights of this projection will help identify the best label words.
|
79 |
+
projection = torch.nn.Linear(config.hidden_size, len(label_map))
|
80 |
+
projection.to(device)
|
81 |
+
|
82 |
+
# Obtain the initial trigger tokens and label mapping
|
83 |
+
if args.initial_trigger:
|
84 |
+
trigger_ids = tokenizer.encode(
|
85 |
+
args.initial_trigger,
|
86 |
+
add_special_tokens=False,
|
87 |
+
add_prefix_space=True
|
88 |
+
)
|
89 |
+
assert len(trigger_ids) == templatizer.num_trigger_tokens
|
90 |
+
else:
|
91 |
+
trigger_ids = [tokenizer.mask_token_id] * templatizer.num_trigger_tokens
|
92 |
+
trigger_ids = torch.tensor(trigger_ids, device=device).unsqueeze(0)
|
93 |
+
|
94 |
+
logger.info('Loading datasets')
|
95 |
+
collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
|
96 |
+
train_dataset = utils.load_trigger_dataset(args.train, templatizer)
|
97 |
+
train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
|
98 |
+
|
99 |
+
optimizer = torch.optim.Adam(projection.parameters(), lr=args.lr)
|
100 |
+
|
101 |
+
scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
|
102 |
+
scores = F.softmax(scores, dim=0)
|
103 |
+
for i, row in enumerate(scores):
|
104 |
+
_, top = row.topk(args.k)
|
105 |
+
decoded = tokenizer.convert_ids_to_tokens(top)
|
106 |
+
logger.info(f"Top k for class {reverse_label_map[i]}: {', '.join(decoded)}")
|
107 |
+
|
108 |
+
logger.info('Training')
|
109 |
+
for i in range(args.iters):
|
110 |
+
pbar = tqdm(train_loader)
|
111 |
+
for model_inputs, labels in pbar:
|
112 |
+
optimizer.zero_grad()
|
113 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
114 |
+
labels = labels.to(device)
|
115 |
+
trigger_mask = model_inputs.pop('trigger_mask')
|
116 |
+
predict_mask = model_inputs.pop('predict_mask')
|
117 |
+
model_inputs = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
|
118 |
+
with torch.no_grad():
|
119 |
+
model(**model_inputs)
|
120 |
+
embeddings = embedding_storage.get()
|
121 |
+
predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
|
122 |
+
logits = projection(predict_embeddings)
|
123 |
+
loss = F.cross_entropy(logits, labels.squeeze(-1))
|
124 |
+
loss.backward()
|
125 |
+
optimizer.step()
|
126 |
+
pbar.set_description(f'loss: {loss : 0.4f}')
|
127 |
+
|
128 |
+
scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
|
129 |
+
scores = F.softmax(scores, dim=0)
|
130 |
+
for i, row in enumerate(scores):
|
131 |
+
_, top = row.topk(args.k)
|
132 |
+
decoded = tokenizer.convert_ids_to_tokens(top)
|
133 |
+
logger.info(f"Top k for class {reverse_label_map[i]}: {', '.join(decoded)}")
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
parser = argparse.ArgumentParser()
|
139 |
+
parser.add_argument('--train', type=Path, required=True, help='Train data path')
|
140 |
+
parser.add_argument('--template', type=str, help='Template string')
|
141 |
+
parser.add_argument('--label-map', type=str, help='JSON object defining label map')
|
142 |
+
parser.add_argument('--initial-trigger', type=str, default=None, help='Manual prompt')
|
143 |
+
parser.add_argument('--label-field', type=str, default='label',
|
144 |
+
help='Name of the label field')
|
145 |
+
parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
|
146 |
+
parser.add_argument('--k', type=int, default=50, help='Number of label tokens to print')
|
147 |
+
parser.add_argument('--bsz', type=int, default=32, help='Batch size')
|
148 |
+
parser.add_argument('--iters', type=int, default=10,
|
149 |
+
help='Number of iterations to run label search')
|
150 |
+
parser.add_argument('--model-name', type=str, default='bert-base-cased',
|
151 |
+
help='Model name passed to HuggingFace AutoX classes.')
|
152 |
+
parser.add_argument('--seed', type=int, default=0)
|
153 |
+
parser.add_argument('--debug', action='store_true')
|
154 |
+
args = parser.parse_args()
|
155 |
+
|
156 |
+
if args.debug:
|
157 |
+
level = logging.DEBUG
|
158 |
+
else:
|
159 |
+
level = logging.INFO
|
160 |
+
logging.basicConfig(level=level)
|
161 |
+
|
162 |
+
main(args)
|
autoprompt/popsicle.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Frozen model with a linear topping...I'm really sleepy...
|
3 |
+
"""
|
4 |
+
import logging
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
8 |
+
from transformers import (
|
9 |
+
AutoConfig,
|
10 |
+
BertConfig,
|
11 |
+
BertForSequenceClassification,
|
12 |
+
PretrainedConfig,
|
13 |
+
RobertaConfig,
|
14 |
+
RobertaForSequenceClassification
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
class Bertsicle(BertForSequenceClassification):
|
23 |
+
def forward(
|
24 |
+
self,
|
25 |
+
input_ids=None,
|
26 |
+
attention_mask=None,
|
27 |
+
token_type_ids=None,
|
28 |
+
position_ids=None,
|
29 |
+
head_mask=None,
|
30 |
+
inputs_embeds=None,
|
31 |
+
labels=None,
|
32 |
+
):
|
33 |
+
with torch.no_grad():
|
34 |
+
outputs = self.bert(
|
35 |
+
input_ids,
|
36 |
+
attention_mask=attention_mask,
|
37 |
+
token_type_ids=token_type_ids,
|
38 |
+
position_ids=position_ids,
|
39 |
+
head_mask=head_mask,
|
40 |
+
inputs_embeds=inputs_embeds,
|
41 |
+
)
|
42 |
+
|
43 |
+
pooled_output = outputs[1] #by ROB
|
44 |
+
pooled_output = outputs[0]
|
45 |
+
pooled_output = pooled_output[:,1:,:] #eliminating CLS token
|
46 |
+
pooled_output = torch.mean(pooled_output, dim=1)
|
47 |
+
|
48 |
+
pooled_output = self.dropout(pooled_output)
|
49 |
+
logits = self.classifier(pooled_output)
|
50 |
+
|
51 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
52 |
+
|
53 |
+
if labels is not None:
|
54 |
+
if self.num_labels == 1:
|
55 |
+
# We are doing regression
|
56 |
+
loss_fct = MSELoss()
|
57 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
58 |
+
else:
|
59 |
+
loss_fct = CrossEntropyLoss()
|
60 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
61 |
+
outputs = (loss,) + outputs
|
62 |
+
|
63 |
+
return outputs # (loss), logits, (hidden_states), (attentions)
|
64 |
+
|
65 |
+
|
66 |
+
class Robertasicle(RobertaForSequenceClassification):
|
67 |
+
def forward(
|
68 |
+
self,
|
69 |
+
input_ids=None,
|
70 |
+
attention_mask=None,
|
71 |
+
token_type_ids=None,
|
72 |
+
position_ids=None,
|
73 |
+
head_mask=None,
|
74 |
+
inputs_embeds=None,
|
75 |
+
labels=None,
|
76 |
+
):
|
77 |
+
with torch.no_grad():
|
78 |
+
outputs = self.roberta(
|
79 |
+
input_ids,
|
80 |
+
attention_mask=attention_mask,
|
81 |
+
token_type_ids=token_type_ids,
|
82 |
+
position_ids=position_ids,
|
83 |
+
head_mask=head_mask,
|
84 |
+
inputs_embeds=inputs_embeds,
|
85 |
+
)
|
86 |
+
sequence_output = outputs[0]
|
87 |
+
sequence_output = sequence_output[:, 1:, :] # eliminating <s> token
|
88 |
+
pooled_sequence_output = torch.mean(sequence_output, dim=1, keepdim=True)
|
89 |
+
logits = self.classifier(pooled_sequence_output)
|
90 |
+
outputs = (logits,) + outputs[2:]
|
91 |
+
if labels is not None:
|
92 |
+
if self.num_labels == 1:
|
93 |
+
# We are doing regression
|
94 |
+
loss_fct = MSELoss()
|
95 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
96 |
+
else:
|
97 |
+
loss_fct = CrossEntropyLoss()
|
98 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
99 |
+
outputs = (loss,) + outputs
|
100 |
+
|
101 |
+
return outputs # (loss), logits, (hidden_states), (attentions)
|
102 |
+
|
103 |
+
|
104 |
+
MODEL_MAPPING = {
|
105 |
+
RobertaConfig: Robertasicle,
|
106 |
+
BertConfig: Bertsicle
|
107 |
+
}
|
108 |
+
|
109 |
+
|
110 |
+
class AutoPopsicle:
|
111 |
+
def __init__(self):
|
112 |
+
raise EnvironmentError('You done goofed. Use `.from_pretrained()` or something.')
|
113 |
+
|
114 |
+
@classmethod
|
115 |
+
def from_config(cls, config):
|
116 |
+
for config_class, model_class in MODEL_MAPPING.items():
|
117 |
+
if isinstance(config, config_class):
|
118 |
+
return model_class(config)
|
119 |
+
|
120 |
+
raise ValueError('We do not support this config.')
|
121 |
+
|
122 |
+
@classmethod
|
123 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
124 |
+
config = kwargs.pop("config", None)
|
125 |
+
if not isinstance(config, PretrainedConfig):
|
126 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
127 |
+
|
128 |
+
for config_class, model_class in MODEL_MAPPING.items():
|
129 |
+
if isinstance(config, config_class):
|
130 |
+
logger.info(f'Config class: {config_class}')
|
131 |
+
logger.info(f'Model class: {model_class}')
|
132 |
+
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
133 |
+
|
134 |
+
raise ValueError('We do not support "{pretrained_model_name_or_path}".')
|
autoprompt/run_linear_probe.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Script for running a linear probe on glue tasks.
|
3 |
+
|
4 |
+
Largely copied from:
|
5 |
+
https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py
|
6 |
+
"""
|
7 |
+
import argparse
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from transformers import AutoConfig, AutoTokenizer, WEIGHTS_NAME, CONFIG_NAME
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
from autoprompt.popsicle import AutoPopsicle
|
18 |
+
import autoprompt.utils as utils
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
def main(args):
|
24 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
25 |
+
|
26 |
+
config = AutoConfig.from_pretrained(args.model_name, num_labels=args.num_labels)
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
28 |
+
model = AutoPopsicle.from_pretrained(args.model_name, config=config)
|
29 |
+
model.to(device)
|
30 |
+
|
31 |
+
collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
|
32 |
+
train_dataset, label_map = utils.load_classification_dataset(
|
33 |
+
args.train,
|
34 |
+
tokenizer,
|
35 |
+
args.field_a,
|
36 |
+
args.field_b,
|
37 |
+
args.label_field
|
38 |
+
)
|
39 |
+
train_loader = DataLoader(train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
|
40 |
+
dev_dataset, _ = utils.load_classification_dataset(
|
41 |
+
args.dev,
|
42 |
+
tokenizer,
|
43 |
+
args.field_a,
|
44 |
+
args.field_b,
|
45 |
+
args.label_field,
|
46 |
+
label_map
|
47 |
+
)
|
48 |
+
dev_loader = DataLoader(dev_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
|
49 |
+
test_dataset, _ = utils.load_classification_dataset(
|
50 |
+
args.test,
|
51 |
+
tokenizer,
|
52 |
+
args.field_a,
|
53 |
+
args.field_b,
|
54 |
+
args.label_field,
|
55 |
+
label_map
|
56 |
+
)
|
57 |
+
test_loader = DataLoader(test_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
|
58 |
+
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=args.lr, weight_decay=1e-6)
|
59 |
+
|
60 |
+
if not args.ckpt_dir.exists():
|
61 |
+
# logger.info(f'Making checkpoint directory: {args.ckpt_dir}')
|
62 |
+
args.ckpt_dir.mkdir(parents=True)
|
63 |
+
elif not args.force_overwrite:
|
64 |
+
raise RuntimeError('Checkpoint directory already exists.')
|
65 |
+
|
66 |
+
best_accuracy = 0
|
67 |
+
try:
|
68 |
+
for epoch in range(args.epochs):
|
69 |
+
logger.info('Training...')
|
70 |
+
model.eval() # Just linear regression - don't want model outputs changing during training.
|
71 |
+
avg_loss = utils.ExponentialMovingAverage()
|
72 |
+
pbar = tqdm(train_loader)
|
73 |
+
for model_inputs, labels in pbar:
|
74 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
75 |
+
labels = labels.to(device)
|
76 |
+
optimizer.zero_grad()
|
77 |
+
logits, *_ = model(**model_inputs)
|
78 |
+
loss = F.cross_entropy(logits, labels.squeeze(-1))
|
79 |
+
loss.backward()
|
80 |
+
optimizer.step()
|
81 |
+
avg_loss.update(loss.item())
|
82 |
+
pbar.set_description(f'loss: {avg_loss.get_metric(): 0.4f}')
|
83 |
+
|
84 |
+
logger.info('Evaluating...')
|
85 |
+
model.eval()
|
86 |
+
correct = 0
|
87 |
+
total = 0
|
88 |
+
for model_inputs, labels in dev_loader:
|
89 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
90 |
+
labels = labels.to(device)
|
91 |
+
logits, *_ = model(**model_inputs)
|
92 |
+
_, preds = logits.max(dim=-1)
|
93 |
+
correct += (preds == labels.squeeze(-1)).sum().item()
|
94 |
+
total += labels.size(0)
|
95 |
+
accuracy = correct / (total + 1e-13)
|
96 |
+
logger.info(f'Accuracy: {accuracy : 0.4f}')
|
97 |
+
|
98 |
+
if accuracy > best_accuracy:
|
99 |
+
logger.info('Best performance so far. Saving...')
|
100 |
+
# torch.save(model.state_dict(), args.ckpt_dir / WEIGHTS_NAME)
|
101 |
+
# model.config.to_json_file(args.ckpt_dir / CONFIG_NAME)
|
102 |
+
model.save_pretrained(args.ckpt_dir)
|
103 |
+
tokenizer.save_pretrained(args.ckpt_dir)
|
104 |
+
best_accuracy = accuracy
|
105 |
+
|
106 |
+
except KeyboardInterrupt:
|
107 |
+
logger.info('Training manually terminated.')
|
108 |
+
|
109 |
+
logger.info('Testing...')
|
110 |
+
checkpoint = torch.load(args.ckpt_dir / WEIGHTS_NAME)
|
111 |
+
model.load_state_dict(checkpoint)
|
112 |
+
model.eval()
|
113 |
+
correct = 0
|
114 |
+
total = 0
|
115 |
+
for model_inputs, labels in test_loader:
|
116 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
117 |
+
labels = labels.to(device)
|
118 |
+
logits, *_ = model(**model_inputs)
|
119 |
+
_, preds = logits.max(dim=-1)
|
120 |
+
correct += (preds == labels.squeeze(-1)).sum().item()
|
121 |
+
total += labels.size(0)
|
122 |
+
accuracy = correct / (total + 1e-13)
|
123 |
+
logger.info(f'Accuracy: {accuracy : 0.4f}')
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == '__main__':
|
127 |
+
parser = argparse.ArgumentParser()
|
128 |
+
parser.add_argument('--model-name', type=str)
|
129 |
+
parser.add_argument('--train', type=Path)
|
130 |
+
parser.add_argument('--dev', type=Path)
|
131 |
+
parser.add_argument('--test', type=Path)
|
132 |
+
parser.add_argument('--field-a', type=str)
|
133 |
+
parser.add_argument('--field-b', type=str, default=None)
|
134 |
+
parser.add_argument('--label-field', type=str, default='label')
|
135 |
+
parser.add_argument('--ckpt-dir', type=Path, default=Path('ckpt/'))
|
136 |
+
parser.add_argument('--num-labels', type=int, default=2)
|
137 |
+
parser.add_argument('--bsz', type=int, default=32)
|
138 |
+
parser.add_argument('--epochs', type=int, default=10)
|
139 |
+
parser.add_argument('--lr', type=float, default=1e-3)
|
140 |
+
parser.add_argument('-f', '--force-overwrite', action='store_true', default=True)
|
141 |
+
parser.add_argument('--debug', action='store_true')
|
142 |
+
parser.add_argument('--log_file', type=str, default='log.txt')
|
143 |
+
args = parser.parse_args()
|
144 |
+
|
145 |
+
if args.debug:
|
146 |
+
level = logging.DEBUG
|
147 |
+
else:
|
148 |
+
level = logging.INFO
|
149 |
+
logging.basicConfig(level=level, filename=args.log_file)
|
150 |
+
|
151 |
+
main(args)
|
autoprompt/utils.py
ADDED
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import copy
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import random
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
+
|
11 |
+
|
12 |
+
MAX_CONTEXT_LEN = 50
|
13 |
+
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
def pad_squeeze_sequence(sequence, *args, **kwargs):
|
19 |
+
"""Squeezes fake batch dimension added by tokenizer before padding sequence."""
|
20 |
+
return pad_sequence([x.squeeze(0) for x in sequence], *args, **kwargs)
|
21 |
+
|
22 |
+
|
23 |
+
class OutputStorage:
|
24 |
+
"""
|
25 |
+
This object stores the intermediate gradients of the output a the given PyTorch module, which
|
26 |
+
otherwise might not be retained.
|
27 |
+
"""
|
28 |
+
def __init__(self, module):
|
29 |
+
self._stored_output = None
|
30 |
+
module.register_forward_hook(self.hook)
|
31 |
+
|
32 |
+
def hook(self, module, input, output):
|
33 |
+
self._stored_output = output
|
34 |
+
|
35 |
+
def get(self):
|
36 |
+
return self._stored_output
|
37 |
+
|
38 |
+
|
39 |
+
class ExponentialMovingAverage:
|
40 |
+
def __init__(self, weight=0.3):
|
41 |
+
self._weight = weight
|
42 |
+
self.reset()
|
43 |
+
|
44 |
+
def update(self, x):
|
45 |
+
self._x += x
|
46 |
+
self._i += 1
|
47 |
+
|
48 |
+
def reset(self):
|
49 |
+
self._x = 0
|
50 |
+
self._i = 0
|
51 |
+
|
52 |
+
def get_metric(self):
|
53 |
+
return self._x / (self._i + 1e-13)
|
54 |
+
|
55 |
+
|
56 |
+
class Collator:
|
57 |
+
"""
|
58 |
+
Collates transformer outputs.
|
59 |
+
"""
|
60 |
+
def __init__(self, pad_token_id=0):
|
61 |
+
self._pad_token_id = pad_token_id
|
62 |
+
|
63 |
+
def __call__(self, features):
|
64 |
+
# Separate the list of inputs and labels
|
65 |
+
model_inputs, labels = list(zip(*features))
|
66 |
+
# Assume that all inputs have the same keys as the first
|
67 |
+
proto_input = model_inputs[0]
|
68 |
+
keys = list(proto_input.keys())
|
69 |
+
padded_inputs = {}
|
70 |
+
for key in keys:
|
71 |
+
if key == 'input_ids':
|
72 |
+
padding_value = self._pad_token_id
|
73 |
+
else:
|
74 |
+
padding_value = 0
|
75 |
+
# NOTE: We need to squeeze to get rid of fake batch dim.
|
76 |
+
sequence = [x[key] for x in model_inputs]
|
77 |
+
padded = pad_squeeze_sequence(sequence, batch_first=True, padding_value=padding_value)
|
78 |
+
padded_inputs[key] = padded
|
79 |
+
labels = pad_squeeze_sequence(labels, batch_first=True, padding_value=0)
|
80 |
+
return padded_inputs, labels
|
81 |
+
|
82 |
+
|
83 |
+
def encode_label(tokenizer, label, tokenize=False):
|
84 |
+
"""
|
85 |
+
Helper function for encoding labels. Deals with the subtleties of handling multiple tokens.
|
86 |
+
"""
|
87 |
+
if isinstance(label, str):
|
88 |
+
if tokenize:
|
89 |
+
# Ensure label is properly tokenized, and only retain first token
|
90 |
+
# if it gets split into multiple tokens. TODO: Make sure this is
|
91 |
+
# desired behavior.
|
92 |
+
tokens = tokenizer.tokenize(label)
|
93 |
+
if len(tokens) > 1:
|
94 |
+
raise ValueError(f'Label "{label}" gets mapped to multiple tokens.')
|
95 |
+
if tokens[0] == tokenizer.unk_token:
|
96 |
+
raise ValueError(f'Label "{label}" gets mapped to unk.')
|
97 |
+
label = tokens[0]
|
98 |
+
encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0)
|
99 |
+
elif isinstance(label, list):
|
100 |
+
encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0)
|
101 |
+
elif isinstance(label, int):
|
102 |
+
encoded = torch.tensor([[label]])
|
103 |
+
return encoded
|
104 |
+
|
105 |
+
|
106 |
+
class TriggerTemplatizer:
|
107 |
+
"""
|
108 |
+
An object to facilitate creating transformers-friendly triggers inputs from a template.
|
109 |
+
|
110 |
+
Parameters
|
111 |
+
==========
|
112 |
+
template : str
|
113 |
+
The template string, comprised of the following tokens:
|
114 |
+
[T] to mark a trigger placeholder.
|
115 |
+
[P] to mark a prediction placeholder.
|
116 |
+
{fields} arbitrary fields instantiated from the dataset instances.
|
117 |
+
For example a NLI template might look like:
|
118 |
+
"[T] [T] [T] {premise} [P] {hypothesis}"
|
119 |
+
tokenizer : PretrainedTokenizer
|
120 |
+
A HuggingFace tokenizer. Must have special trigger and predict tokens.
|
121 |
+
add_special_tokens : bool
|
122 |
+
Whether or not to add special tokens when encoding. Default: False.
|
123 |
+
"""
|
124 |
+
def __init__(self,
|
125 |
+
template,
|
126 |
+
config,
|
127 |
+
tokenizer,
|
128 |
+
label_field='label',
|
129 |
+
label_map=None,
|
130 |
+
tokenize_labels=False,
|
131 |
+
add_special_tokens=False,
|
132 |
+
use_ctx=False):
|
133 |
+
if not hasattr(tokenizer, 'predict_token') or \
|
134 |
+
not hasattr(tokenizer, 'trigger_token'):
|
135 |
+
raise ValueError(
|
136 |
+
'Tokenizer missing special trigger and predict tokens in vocab.'
|
137 |
+
'Use `utils.add_special_tokens` to add them.'
|
138 |
+
)
|
139 |
+
self._template = template
|
140 |
+
self._config = config
|
141 |
+
self._tokenizer = tokenizer
|
142 |
+
self._label_field = label_field
|
143 |
+
self._label_map = label_map
|
144 |
+
self._tokenize_labels = tokenize_labels
|
145 |
+
self._add_special_tokens = add_special_tokens
|
146 |
+
self._use_ctx = use_ctx
|
147 |
+
|
148 |
+
@property
|
149 |
+
def num_trigger_tokens(self):
|
150 |
+
return sum(token == '[T]' for token in self._template.split())
|
151 |
+
|
152 |
+
def __call__(self, format_kwargs):
|
153 |
+
# Format the template string
|
154 |
+
format_kwargs = format_kwargs.copy()
|
155 |
+
label = format_kwargs.pop(self._label_field)
|
156 |
+
text = self._template.format(**format_kwargs)
|
157 |
+
if label is None:
|
158 |
+
raise Exception(f'Bad data: {text}')
|
159 |
+
|
160 |
+
# Have the tokenizer encode the text and process the output to:
|
161 |
+
# - Create a trigger and predict mask
|
162 |
+
# - Replace the predict token with a mask token
|
163 |
+
model_inputs = self._tokenizer.encode_plus(
|
164 |
+
text,
|
165 |
+
add_special_tokens=self._add_special_tokens,
|
166 |
+
return_tensors='pt'
|
167 |
+
)
|
168 |
+
input_ids = model_inputs['input_ids']
|
169 |
+
trigger_mask = input_ids.eq(self._tokenizer.trigger_token_id)
|
170 |
+
predict_mask = input_ids.eq(self._tokenizer.predict_token_id)
|
171 |
+
input_ids[predict_mask] = self._tokenizer.mask_token_id
|
172 |
+
|
173 |
+
model_inputs['trigger_mask'] = trigger_mask
|
174 |
+
model_inputs['predict_mask'] = predict_mask
|
175 |
+
|
176 |
+
# For relation extraction with BERT, update token_type_ids to reflect the two different sequences
|
177 |
+
if self._use_ctx and self._config.model_type == 'bert':
|
178 |
+
sep_token_indices = (input_ids.squeeze(0) == self._tokenizer.convert_tokens_to_ids(self._tokenizer.sep_token)).nonzero().flatten()
|
179 |
+
sequence_b_indices = torch.arange(sep_token_indices[0], sep_token_indices[1] + 1).long().unsqueeze(0)
|
180 |
+
model_inputs['token_type_ids'].scatter_(1, sequence_b_indices, 1)
|
181 |
+
|
182 |
+
# Encode the label(s)
|
183 |
+
if self._label_map is not None:
|
184 |
+
label = self._label_map[label]
|
185 |
+
label_id = encode_label(
|
186 |
+
tokenizer=self._tokenizer,
|
187 |
+
label=label,
|
188 |
+
tokenize=self._tokenize_labels
|
189 |
+
)
|
190 |
+
|
191 |
+
return model_inputs, label_id
|
192 |
+
|
193 |
+
|
194 |
+
def add_task_specific_tokens(tokenizer):
|
195 |
+
tokenizer.add_special_tokens({
|
196 |
+
'additional_special_tokens': ['[T]', '[P]', '[Y]']
|
197 |
+
})
|
198 |
+
tokenizer.trigger_token = '[T]'
|
199 |
+
tokenizer.trigger_token_id = tokenizer.convert_tokens_to_ids('[T]')
|
200 |
+
tokenizer.predict_token = '[P]'
|
201 |
+
tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
|
202 |
+
# NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
|
203 |
+
# tokenizer.lama_x = '[X]'
|
204 |
+
# tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
|
205 |
+
tokenizer.lama_y = '[Y]'
|
206 |
+
tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]')
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
def load_tsv(fname):
|
211 |
+
with open(fname, 'r') as f:
|
212 |
+
reader = csv.DictReader(f, delimiter='\t')
|
213 |
+
for row in reader:
|
214 |
+
yield row
|
215 |
+
|
216 |
+
|
217 |
+
def load_jsonl(fname):
|
218 |
+
with open(fname, 'r') as f:
|
219 |
+
for line in f:
|
220 |
+
yield json.loads(line)
|
221 |
+
|
222 |
+
|
223 |
+
LOADERS = {
|
224 |
+
'.tsv': load_tsv,
|
225 |
+
'.jsonl': load_jsonl
|
226 |
+
}
|
227 |
+
|
228 |
+
|
229 |
+
def load_trigger_dataset(fname, templatizer, use_ctx, limit=None):
|
230 |
+
loader = LOADERS[fname.suffix]
|
231 |
+
instances = []
|
232 |
+
|
233 |
+
for x in loader(fname):
|
234 |
+
try:
|
235 |
+
if use_ctx:
|
236 |
+
# For relation extraction, skip facts that don't have context sentence
|
237 |
+
if 'evidences' not in x:
|
238 |
+
logger.warning('Skipping RE sample because it lacks context sentences: {}'.format(x))
|
239 |
+
continue
|
240 |
+
|
241 |
+
evidences = x['evidences']
|
242 |
+
|
243 |
+
# Randomly pick a context sentence
|
244 |
+
obj_surface, masked_sent = random.choice([(evidence['obj_surface'], evidence['masked_sentence']) for evidence in evidences])
|
245 |
+
words = masked_sent.split()
|
246 |
+
if len(words) > MAX_CONTEXT_LEN:
|
247 |
+
# If the masked sentence is too long, use the first X tokens. For training we want to keep as many samples as we can.
|
248 |
+
masked_sent = ' '.join(words[:MAX_CONTEXT_LEN])
|
249 |
+
|
250 |
+
# If truncated context sentence still has MASK, we need to replace it with object surface
|
251 |
+
# We explicitly use [MASK] because all TREx fact's context sentences use it
|
252 |
+
context = masked_sent.replace('[MASK]', obj_surface)
|
253 |
+
x['context'] = context
|
254 |
+
model_inputs, label_id = templatizer(x)
|
255 |
+
else:
|
256 |
+
model_inputs, label_id = templatizer(x)
|
257 |
+
except ValueError as e:
|
258 |
+
logger.warning('Encountered error "%s" when processing "%s". Skipping.', e, x)
|
259 |
+
continue
|
260 |
+
else:
|
261 |
+
instances.append((model_inputs, label_id))
|
262 |
+
if limit:
|
263 |
+
return random.sample(instances, limit)
|
264 |
+
else:
|
265 |
+
return instances
|
266 |
+
|
267 |
+
|
268 |
+
def load_augmented_trigger_dataset(fname, templatizer, limit=None):
|
269 |
+
loader = LOADERS[fname.suffix]
|
270 |
+
instances = []
|
271 |
+
|
272 |
+
# For augmented relation extraction, we need to replace obj_label with another obj_label, and replace obj_surface with a surface form of the new obj_label
|
273 |
+
unique_objs_dict = defaultdict(list)
|
274 |
+
# Also for augmented relation extraction, we need to accumulate all facts and process them afterwards
|
275 |
+
facts = []
|
276 |
+
|
277 |
+
for x in loader(fname):
|
278 |
+
try:
|
279 |
+
sub_label = x['sub_label']
|
280 |
+
obj_label = x['obj_label']
|
281 |
+
|
282 |
+
# For relation extraction, skip facts that don't have context sentence
|
283 |
+
if 'evidences' not in x:
|
284 |
+
logger.warning('Skipping RE sample because it lacks context sentences: {}'.format(x))
|
285 |
+
continue
|
286 |
+
|
287 |
+
evidences = x['evidences']
|
288 |
+
|
289 |
+
# Gather all UNIQUE objects and their surface forms if its augmented relation extraction
|
290 |
+
for evidence in evidences:
|
291 |
+
obj_surface = evidence['obj_surface']
|
292 |
+
masked_sent = evidence['masked_sentence']
|
293 |
+
unique_objs_dict[obj_label].append(obj_surface)
|
294 |
+
|
295 |
+
# Randomly pick a context sentence
|
296 |
+
obj_surface, masked_sent = random.choice([(evidence['obj_surface'], evidence['masked_sentence']) for evidence in evidences])
|
297 |
+
words = masked_sent.split()
|
298 |
+
if len(words) > MAX_CONTEXT_LEN:
|
299 |
+
# If the masked sentence is too long, use the first X tokens. For training we want to keep as many samples as we can.
|
300 |
+
masked_sent = ' '.join(words[:MAX_CONTEXT_LEN])
|
301 |
+
|
302 |
+
x['context'] = masked_sent
|
303 |
+
facts.append(x)
|
304 |
+
except ValueError as e:
|
305 |
+
logger.warning('Encountered error "%s" when processing "%s". Skipping.', e, x)
|
306 |
+
|
307 |
+
# Go through all facts and replace each object with a new one. Also insert the new object (surface form) into the masked sentence
|
308 |
+
synth_facts = []
|
309 |
+
for fact in facts:
|
310 |
+
sub_label = fact['sub_label']
|
311 |
+
obj_label = fact['obj_label']
|
312 |
+
masked_sent = fact['context']
|
313 |
+
# print('Original fact: ({}, {}, {})'.format(sub_label, obj_label, masked_sent))
|
314 |
+
synth_obj_label = random.choice([x for x in unique_objs_dict.keys() if x != obj_label])
|
315 |
+
synth_obj_surface = random.choice(unique_objs_dict[synth_obj_label])
|
316 |
+
synth_ctx = masked_sent.replace('[MASK]', synth_obj_surface)
|
317 |
+
# print('Synthetic fact: ({}, {}, {})\n'.format(sub_label, synth_obj_label, synth_ctx))
|
318 |
+
# Reassign the labels and context sentence
|
319 |
+
synth_fact = copy.deepcopy(fact)
|
320 |
+
synth_fact['sub_label'] = sub_label
|
321 |
+
synth_fact['obj_label'] = synth_obj_label
|
322 |
+
synth_fact['context'] = synth_ctx
|
323 |
+
synth_facts.append(synth_fact)
|
324 |
+
|
325 |
+
# Go through facts, templatize each one, then append them to instances
|
326 |
+
for fact in synth_facts:
|
327 |
+
model_inputs, label_id = templatizer(fact)
|
328 |
+
instances.append((model_inputs, label_id))
|
329 |
+
|
330 |
+
if limit:
|
331 |
+
return random.sample(instances, limit)
|
332 |
+
else:
|
333 |
+
return instances
|
334 |
+
|
335 |
+
|
336 |
+
def load_classification_dataset(
|
337 |
+
fname,
|
338 |
+
tokenizer,
|
339 |
+
input_field_a,
|
340 |
+
input_field_b=None,
|
341 |
+
label_field='label',
|
342 |
+
label_map=None,
|
343 |
+
limit=None
|
344 |
+
):
|
345 |
+
"""
|
346 |
+
Loads a dataset for classification
|
347 |
+
|
348 |
+
Parameters
|
349 |
+
==========
|
350 |
+
tokenizer : transformers.PretrainedTokenizer
|
351 |
+
Maps text to id tensors.
|
352 |
+
sentence1 :
|
353 |
+
"""
|
354 |
+
instances = []
|
355 |
+
label_map = label_map or {}
|
356 |
+
loader = LOADERS[fname.suffix]
|
357 |
+
for instance in loader(fname):
|
358 |
+
logger.debug(instance)
|
359 |
+
model_inputs = tokenizer.encode_plus(
|
360 |
+
instance[input_field_a],
|
361 |
+
instance[input_field_b] if input_field_b else None,
|
362 |
+
add_special_tokens=True,
|
363 |
+
# add_prefix_space=True,
|
364 |
+
return_tensors='pt'
|
365 |
+
)
|
366 |
+
logger.debug(model_inputs)
|
367 |
+
label = instance[label_field]
|
368 |
+
if label not in label_map:
|
369 |
+
label_map[label] = len(label_map)
|
370 |
+
label_id = label_map[label]
|
371 |
+
label_id = torch.tensor([[label_id]]) # To make collator expectation
|
372 |
+
logger.debug(f'Label id: {label_id}')
|
373 |
+
instances.append((model_inputs, label_id))
|
374 |
+
if limit:
|
375 |
+
instances = random.sample(instances, limit)
|
376 |
+
return instances, label_map
|
prompts/fact_retrieval_bert_prompts.jsonl
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"relation": "P1001", "template": "[X]vik nationwide disabilities policing within [Y]."}
|
2 |
+
{"relation": "P101", "template": "[X] probability earliest fame totaled studying [Y]."}
|
3 |
+
{"relation": "P103", "template": "[X]PA communerug speaks proper [Y]."}
|
4 |
+
{"relation": "P106", "template": "[X] supporters studied politicians musician turned [Y]."}
|
5 |
+
{"relation": "P108", "template": "[X] 1987adeNBC computing succeeded [Y]."}
|
6 |
+
{"relation": "P127", "template": "[X] is hindwings mainline architecture within [Y]."}
|
7 |
+
{"relation": "P1303", "template": "[X] playingdrum concertoative electric [Y]."}
|
8 |
+
{"relation": "P131", "template": "[X]ediatric close suburb throughout northwest [Y]."}
|
9 |
+
{"relation": "P136", "template": "[X] freaking genre orchestra fiction acid [Y]."}
|
10 |
+
{"relation": "P1376", "template": "[X] boasts native territory traditionally called [Y]."}
|
11 |
+
{"relation": "P138", "template": "[X] consistslanche classical name of [Y]."}
|
12 |
+
{"relation": "P140", "template": "[X]urn openly explicitly mosques practicing [Y]."}
|
13 |
+
{"relation": "P1412", "template": "[X] receivedorganisation 1904 speaking only [Y]."}
|
14 |
+
{"relation": "P159", "template": "[X] isnky galleries headquartered in [Y]."}
|
15 |
+
{"relation": "P17", "template": "[X] is association footballled southeastern [Y]."}
|
16 |
+
{"relation": "P176", "template": "[X] was flight series manufactured by [Y]."}
|
17 |
+
{"relation": "P178", "template": "[X] is memory arcade branding by [Y]."}
|
18 |
+
{"relation": "P19", "template": "[X] clocks literary economist relocated to [Y]."}
|
19 |
+
{"relation": "P190", "template": "[X] proceeded worldwidedick offices near [Y]."}
|
20 |
+
{"relation": "P20", "template": "[X] reorganizationotype photographic studio in [Y]."}
|
21 |
+
{"relation": "P264", "template": "[X] cameo explanation\u00f6table sued [Y]."}
|
22 |
+
{"relation": "P27", "template": "[X] m\u00b3 badminton pieces internationally representing [Y]."}
|
23 |
+
{"relation": "P276", "template": "[X] consists kilograms centred neighborhoods in [Y]."}
|
24 |
+
{"relation": "P279", "template": "[X] is \u00ee adequately termed coated [Y]."}
|
25 |
+
{"relation": "P30", "template": "[X] is commune polar continent in [Y]."}
|
26 |
+
{"relation": "P31", "template": "[X] isious 1970s southwardlier [Y]."}
|
27 |
+
{"relation": "P36", "template": "[X] includesiidae geologic countryside near [Y]."}
|
28 |
+
{"relation": "P361", "template": "[X] isaul archaic section of [Y]."}
|
29 |
+
{"relation": "P364", "template": "[X]dak \u20ac dancers speak standard [Y]."}
|
30 |
+
{"relation": "P37", "template": "[X]inen dialects resembled officially exclusively [Y]."}
|
31 |
+
{"relation": "P39", "template": "[X] explorers voting municipal \u2192 consecrated [Y]."}
|
32 |
+
{"relation": "P407", "template": "[X] playedi\u0107 every dialect but [Y]."}
|
33 |
+
{"relation": "P413", "template": "[X] played colors skier \u2194 defensive [Y]."}
|
34 |
+
{"relation": "P449", "template": "[X] uncredited recording remake aired on [Y]."}
|
35 |
+
{"relation": "P463", "template": "[X] splits artisticlogy prior joining [Y]."}
|
36 |
+
{"relation": "P47", "template": "[X] shares undrafted border northeast neighbours [Y]."}
|
37 |
+
{"relation": "P495", "template": "[X] album spanninggie chart in [Y]."}
|
38 |
+
{"relation": "P527", "template": "[X] nickname involves \u032f\u00bddized [Y]."}
|
39 |
+
{"relation": "P530", "template": "[X] nightclubrah preceding relations with [Y]."}
|
40 |
+
{"relation": "P740", "template": "[X] refers drum blog centred downtown [Y]."}
|
41 |
+
{"relation": "P937", "template": "[X] vol \u300elson gallery in [Y]."}
|
prompts/fact_retrieval_roberta_prompts.jsonl
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"relation": "P1001", "template": " [X]\u00a2List unsu rivers spanning [Y] ."}
|
2 |
+
{"relation": "P101", "template": " [X] 1830 dissertation applying mathsucci [Y] ."}
|
3 |
+
{"relation": "P103", "template": " [X]neau optionally fluent!?\" traditional [Y] ."}
|
4 |
+
{"relation": "P106", "template": " [X] (), astronomers businessman\u00b7former [Y] ."}
|
5 |
+
{"relation": "P108", "template": " [X] heads opio computer divisionersen [Y] ."}
|
6 |
+
{"relation": "P127", "template": " [X] picThom unwillingness officially governs [Y] ."}
|
7 |
+
{"relation": "P1303", "template": " [X]Trump learned soloKeefe classical [Y] ."}
|
8 |
+
{"relation": "P131", "template": " [X] scenic neighbourhood occurred enqu northeastern [Y] ."}
|
9 |
+
{"relation": "P136", "template": " [X] blends postwar hostage drama sax [Y] ."}
|
10 |
+
{"relation": "P1376", "template": " [X] limestone depositedati boroughDepending [Y] ."}
|
11 |
+
{"relation": "P138", "template": " [X] =alysis northern spellingSaint [Y] ."}
|
12 |
+
{"relation": "P140", "template": " [X] traced pagan fascism individuality extremist [Y] ."}
|
13 |
+
{"relation": "P1412", "template": " [X] translatedANCauld writings binaries [Y] ."}
|
14 |
+
{"relation": "P159", "template": " [X] spinsCompany organisedLocation near [Y] ."}
|
15 |
+
{"relation": "P17", "template": " [X]exec scenic provinces iodine northeastern [Y] ."}
|
16 |
+
{"relation": "P176", "template": " [X] 125definition enormously stunned manufacturer [Y] ."}
|
17 |
+
{"relation": "P178", "template": " [X] 1987 floppy simulator users sued [Y] ."}
|
18 |
+
{"relation": "P19", "template": " [X] 2002 protesting disco constructionamine [Y] ."}
|
19 |
+
{"relation": "P190", "template": " [X] flight facultiesyna arrivesfolios [Y] ."}
|
20 |
+
{"relation": "P20", "template": " [X].. enigmatic twentieth nowadays near [Y] ."}
|
21 |
+
{"relation": "P264", "template": " [X] touring 1958 defunct videog label [Y] ."}
|
22 |
+
{"relation": "P27", "template": " [X] offic organise forests statutes northwestern [Y] ."}
|
23 |
+
{"relation": "P276", "template": " [X] manoeuv constructs whistleblowers hills near [Y] ."}
|
24 |
+
{"relation": "P279", "template": " [X],formerly prayers unstaceous [Y] ."}
|
25 |
+
{"relation": "P30", "template": " [X] coral caves symb polar zone [Y] ."}
|
26 |
+
{"relation": "P31", "template": " [X] (), therapists nationallyrecorded enchanted [Y] ."}
|
27 |
+
{"relation": "P36", "template": " [X] 1954 misinterpretburg narrowly battered [Y] ."}
|
28 |
+
{"relation": "P361", "template": " [X], supplementaryfoot structuresNorthern [Y] ."}
|
29 |
+
{"relation": "P364", "template": " [X]vanathering preferred languagesEnglish [Y] ."}
|
30 |
+
{"relation": "P37", "template": " [X]onen tribes descending speak mainly [Y] ."}
|
31 |
+
{"relation": "P39", "template": " [X] billionaire elected unp\u200b\u200bCatholic [Y] ."}
|
32 |
+
{"relation": "P407", "template": " [X] scaven pronunciation.*Wikipedia speaks [Y] ."}
|
33 |
+
{"relation": "P413", "template": " [X],'' (), ex-,Liverpool [Y] ."}
|
34 |
+
{"relation": "P449", "template": " [X] premiered 1989 simulatively instinctively [Y] ."}
|
35 |
+
{"relation": "P463", "template": " [X] joins reformedolitical endangered grouping [Y] ."}
|
36 |
+
{"relation": "P47", "template": " [X] combinesfill marry territory surrounding [Y] ."}
|
37 |
+
{"relation": "P495", "template": " [X] condom announces manufacturer residence exported [Y] ."}
|
38 |
+
{"relation": "P527", "template": " [X] minus asylum cooked = compressed [Y] ."}
|
39 |
+
{"relation": "P530", "template": " [X]varOriginally kidnappedstrate neighboring [Y] ."}
|
40 |
+
{"relation": "P740", "template": " [X] prefersLondon whilst 182 favors [Y] ."}
|
41 |
+
{"relation": "P937", "template": " [X] bicycles investments railway neighborhoodAlternatively [Y] ."}
|
prompts/relation_extraction_bert_prompts.jsonl
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"relation": "P1001", "template": "[X] dispatched state consul federally to [Y]."}
|
2 |
+
{"relation": "P101", "template": "[X]icidalology fascinated textbook on [Y]."}
|
3 |
+
{"relation": "P103", "template": "[X] sent literary visa speaking predominantly [Y]."}
|
4 |
+
{"relation": "P106", "template": "[X] as invented firstractical aspiring [Y]."}
|
5 |
+
{"relation": "P108", "template": "[X] funded transmissions business involvement at [Y]."}
|
6 |
+
{"relation": "P127", "template": "[X] sentuti limo sponsorship to [Y]."}
|
7 |
+
{"relation": "P1303", "template": "[X] ] podcast 1935 practices unison [Y]."}
|
8 |
+
{"relation": "P131", "template": "[X] fewer congressional consul corporation bordering [Y]."}
|
9 |
+
{"relation": "P136", "template": "[X] drama bacteriatitled 80s cosmic [Y]."}
|
10 |
+
{"relation": "P138", "template": "[X] positively cited the town nicknamed [Y]."}
|
11 |
+
{"relation": "P140", "template": "[X] 2006 revelation convertedtsky practiced [Y]."}
|
12 |
+
{"relation": "P1412", "template": "[X] imported colleges translations exports speak [Y]."}
|
13 |
+
{"relation": "P159", "template": "[X]rica headquartered town across from [Y]."}
|
14 |
+
{"relation": "P17", "template": "[X] constituteronological country embassy to [Y]."}
|
15 |
+
{"relation": "P176", "template": "[X] became plays sponsor co with [Y]."}
|
16 |
+
{"relation": "P178", "template": "[X] game handed showcased separately by [Y]."}
|
17 |
+
{"relation": "P19", "template": "[X]lancheheim grew house in [Y]."}
|
18 |
+
{"relation": "P190", "template": "[X] attended waived both cities including [Y]."}
|
19 |
+
{"relation": "P20", "template": "[X]rseyjee maintained apartment in [Y]."}
|
20 |
+
{"relation": "P264", "template": "[X] became commemorated label label succeeding [Y]."}
|
21 |
+
{"relation": "P27", "template": "[X] country goals diaspora diplomat visited [Y]."}
|
22 |
+
{"relation": "P276", "template": "[X] visited crore sister town to [Y]."}
|
23 |
+
{"relation": "P279", "template": "[X]districtutical\u00e8ne word resembling [Y]."}
|
24 |
+
{"relation": "P30", "template": "[X] subfamily pardon globallyinae throughout [Y]."}
|
25 |
+
{"relation": "P31", "template": "[X] nm charitiespository nicknamed underwater [Y]."}
|
26 |
+
{"relation": "P36", "template": "[X] sued wraps owner city of [Y]."}
|
27 |
+
{"relation": "P361", "template": "[X] passwordU emblem inspired by [Y]."}
|
28 |
+
{"relation": "P364", "template": "[X] translated mistress culturally language notably [Y]."}
|
29 |
+
{"relation": "P37", "template": "[X] called countries speaking originually [Y]."}
|
30 |
+
{"relation": "P39", "template": "[X]lina \u2500 fifteenthously supreme [Y]."}
|
31 |
+
{"relation": "P407", "template": "[X] sent vocalist languages foreign especially [Y]."}
|
32 |
+
{"relation": "P413", "template": "[X] acts sentiment rookie minimum prone [Y]."}
|
33 |
+
{"relation": "P449", "template": "[X] novels channel similarly also joined [Y]."}
|
34 |
+
{"relation": "P463", "template": "[X] member testified frontman founded also [Y]."}
|
35 |
+
{"relation": "P47", "template": "[X] became consulate will include [Y]."}
|
36 |
+
{"relation": "P495", "template": "[X] shows website country abroad includes [Y]."}
|
37 |
+
{"relation": "P530", "template": "[X] send globally dedicated embassy to [Y]."}
|
38 |
+
{"relation": "P740", "template": "[X] music compliment residents resident in [Y]."}
|
39 |
+
{"relation": "P937", "template": "[X] described courthouse residency career near [Y]."}
|
prompts/relation_extraction_roberta_prompts.jsonl
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{"relation": "P1001", "template": "[X] congratulated killers counterparts residing outage [Y] ."}
|
2 |
+
{"relation": "P101", "template": "[X]itations illustratingModern\u2010 risked [Y] ."}
|
3 |
+
{"relation": "P103", "template": "[X] website canceled learn languageposition [Y] ."}
|
4 |
+
{"relation": "P106", "template": "[X]Officersoglu internationally renown trained [Y] ."}
|
5 |
+
{"relation": "P108", "template": "[X] culinary \u00a9 fixtures file courtesy [Y] ."}
|
6 |
+
{"relation": "P127", "template": "[X] proudly celebrating playthrough ties cultured [Y] ."}
|
7 |
+
{"relation": "P1303", "template": "[X] gubernatorial 410 antique vibritone [Y] ."}
|
8 |
+
{"relation": "P131", "template": "[X] \"{\">phys town bordering [Y] ."}
|
9 |
+
{"relation": "P136", "template": "[X] poking maneuvers genre synonymous baseline [Y] ."}
|
10 |
+
{"relation": "P138", "template": "[X] slideshow painting spelling homage ()); [Y] ."}
|
11 |
+
{"relation": "P140", "template": "[X] modified kosher spiritualitycert imitate [Y] ."}
|
12 |
+
{"relation": "P1412", "template": "[X] translating pled spoken callback fluent [Y] ."}
|
13 |
+
{"relation": "P159", "template": "[X] hometown bonding hahaVisit downtown [Y] ."}
|
14 |
+
{"relation": "P17", "template": "[X] embassy factual diplomatic ambassadorooked [Y] ."}
|
15 |
+
{"relation": "P176", "template": "[X] sponsorship respectfully complimentary courtesy fuckin [Y] ."}
|
16 |
+
{"relation": "P178", "template": "[X] wikiPlanetSOURCE sponsored reckon [Y] ."}
|
17 |
+
{"relation": "P19", "template": "[X] slideshow referencing correctness hometown continent [Y] ."}
|
18 |
+
{"relation": "P190", "template": "[X] planetaking luggage transfer reaching [Y] ."}
|
19 |
+
{"relation": "P20", "template": "[X] ironically resided located recalling downtown [Y] ."}
|
20 |
+
{"relation": "P264", "template": "[X] claims primary label membershipdisc [Y] ."}
|
21 |
+
{"relation": "P27", "template": "[X] smugglers smuggled davidjl forcibly affordability [Y] ."}
|
22 |
+
{"relation": "P276", "template": "[X] photographed>:Folder cliffs overlooking [Y] ."}
|
23 |
+
{"relation": "P279", "template": "[X]enez sculpture disguised mailboxSensor [Y] ."}
|
24 |
+
{"relation": "P30", "template": "[X] tropical continent tropicalmessageoutheastern [Y] ."}
|
25 |
+
{"relation": "P31", "template": "[X] bullies campuses hypothetical substitutiononic [Y] ."}
|
26 |
+
{"relation": "P36", "template": "[X] border*.NOWVisit downtown [Y] ."}
|
27 |
+
{"relation": "P361", "template": "[X] ~/FlickrFORE blessing representing [Y] ."}
|
28 |
+
{"relation": "P364", "template": "[X] population language predomin smoker installer [Y] ."}
|
29 |
+
{"relation": "P37", "template": "[X] screamed visibly fluent descendants nutrients [Y] ."}
|
30 |
+
{"relation": "P39", "template": "[X] slideshow photo\u30aa workforce appointed [Y] ."}
|
31 |
+
{"relation": "P407", "template": "[X] screened pioneering documentaries translated curry [Y] ."}
|
32 |
+
{"relation": "P413", "template": "[X] learnedailed springTherefore veteran [Y] ."}
|
33 |
+
{"relation": "P449", "template": "[X] slideshow courtesy recommendation television broadcaster [Y] ."}
|
34 |
+
{"relation": "P463", "template": "[X] facebook referencing summarizes monikerTeam [Y] ."}
|
35 |
+
{"relation": "P47", "template": "[X] aquatic contacted diplomatic consulate imperialist [Y] ."}
|
36 |
+
{"relation": "P495", "template": "[X] webpage highlighting cultural exportsrero [Y] ."}
|
37 |
+
{"relation": "P530", "template": "[X]ECD establishes diplomatic ties fut [Y] ."}
|
38 |
+
{"relation": "P740", "template": "[X]empt adjoining merchants utilized downtown [Y] ."}
|
39 |
+
{"relation": "P937", "template": "[X]\u00df died\"\" 1931 barbecue [Y] ."}
|
pytest.ini
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[pytest]
|
2 |
+
testpaths = tests/
|
3 |
+
pythonpath = ./
|
4 |
+
log_format = %(asctime)s - %(levelname)s - %(name)s - %(message)s
|
5 |
+
log_level = DEBUG
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==0.79.0
|
2 |
+
tqdm==4.49.0
|
3 |
+
pandas==1.2.1
|
4 |
+
numpy==1.17.2
|
5 |
+
torch==1.4.0
|
6 |
+
transformers==2.9.1
|
7 |
+
spacy==2.2.0
|
8 |
+
termcolor==1.1.0
|
9 |
+
colorama==0.4.1
|
10 |
+
matplotlib==3.1.1
|
11 |
+
pytest
|
scripts/run_fact_retrieval_example.sh
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Experiment 8
|
3 |
+
# Task: fact retrieval
|
4 |
+
# Model: RoBERTa
|
5 |
+
# Batch sizes: 56
|
6 |
+
# Iters: 1000
|
7 |
+
# Filtering: True
|
8 |
+
|
9 |
+
datadir=$1
|
10 |
+
logfile=$2
|
11 |
+
|
12 |
+
# Clear files
|
13 |
+
cat /dev/null > $logfile
|
14 |
+
cat /dev/null > ${logfile}.log
|
15 |
+
|
16 |
+
for path in $datadir/*; do
|
17 |
+
filename=$(basename "$path")
|
18 |
+
time CUDA_VISIBLE_DEVICES=3 python -m autoprompt.create_trigger \
|
19 |
+
--train $path/train.jsonl \
|
20 |
+
--dev $path/dev.jsonl \
|
21 |
+
--template '<s> {sub_label} [T] [T] [T] [T] [T] [P] . </s>' \
|
22 |
+
--num-cand 10 \
|
23 |
+
--accumulation-steps 1 \
|
24 |
+
--model-name roberta-large \
|
25 |
+
--bsz 56 \
|
26 |
+
--eval-size 56 \
|
27 |
+
--iters 1000 \
|
28 |
+
--label-field 'obj_label' \
|
29 |
+
--tokenize-labels \
|
30 |
+
--filter \
|
31 |
+
--print-lama >> $logfile 2>> ${logfile}.log
|
32 |
+
done
|
scripts/run_relation_extraction_example.sh
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
# Experiment 9
|
3 |
+
# Task: relation extraction
|
4 |
+
# Model: BERT
|
5 |
+
# Batch size: 32
|
6 |
+
# Iters: 500
|
7 |
+
# Filtering: True
|
8 |
+
|
9 |
+
datadir=$1
|
10 |
+
logfile=$2
|
11 |
+
|
12 |
+
# Clear files
|
13 |
+
cat /dev/null > $logfile
|
14 |
+
cat /dev/null > ${logfile}.log
|
15 |
+
|
16 |
+
for path in $datadir/*; do
|
17 |
+
filename=$(basename "$path")
|
18 |
+
time CUDA_VISIBLE_DEVICES=4 python -m autoprompt.create_trigger \
|
19 |
+
--train $path/train.jsonl \
|
20 |
+
--dev $path/dev.jsonl \
|
21 |
+
--template '[CLS] {context} [SEP] {sub_label} [T] [T] [T] [T] [T] [P] . [SEP]' \
|
22 |
+
--num-cand 10 \
|
23 |
+
--accumulation-steps 1 \
|
24 |
+
--model-name bert-base-cased \
|
25 |
+
--bsz 32 \
|
26 |
+
--eval-size 32 \
|
27 |
+
--iters 500 \
|
28 |
+
--label-field 'obj_label' \
|
29 |
+
--tokenize-labels \
|
30 |
+
--filter \
|
31 |
+
--print-lama \
|
32 |
+
--use-ctx >> $logfile 2>> ${logfile}.log
|
33 |
+
done
|
setup.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import setuptools
|
3 |
+
import sys
|
4 |
+
|
5 |
+
|
6 |
+
# Load README to get long description.
|
7 |
+
with open('README.md') as f:
|
8 |
+
_LONG_DESCRIPTION = f.read()
|
9 |
+
|
10 |
+
|
11 |
+
setuptools.setup(
|
12 |
+
name='autoprompt',
|
13 |
+
version='0.0.1',
|
14 |
+
description='AutoPrompt',
|
15 |
+
long_description=_LONG_DESCRIPTION,
|
16 |
+
long_description_content_type='text/markdown',
|
17 |
+
author='UCI NLP',
|
18 |
+
url='https://github.com/ucinlp/autoprompt',
|
19 |
+
packages=setuptools.find_packages(),
|
20 |
+
install_requires=[ ],
|
21 |
+
extras_require={
|
22 |
+
'test': ['pytest']
|
23 |
+
},
|
24 |
+
classifiers=[
|
25 |
+
'Intended Audience :: Science/Research',
|
26 |
+
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
27 |
+
],
|
28 |
+
keywords='text nlp machinelearning',
|
29 |
+
)
|
tests/test_create_trigger.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unittest import TestCase
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer
|
5 |
+
|
6 |
+
import autoprompt.create_trigger as ct
|
7 |
+
|
8 |
+
|
9 |
+
def _load(model_name):
|
10 |
+
config = AutoConfig.from_pretrained('bert-base-cased')
|
11 |
+
model = AutoModelWithLMHead.from_pretrained('bert-base-cased')
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
|
13 |
+
return config, model, tokenizer
|
14 |
+
|
15 |
+
|
16 |
+
class TestGetEmbeddings(TestCase):
|
17 |
+
def test_bert(self):
|
18 |
+
model_name = 'bert-base-cased'
|
19 |
+
config, model, tokenizer = _load(model_name)
|
20 |
+
embeddings = ct.get_embeddings(model, config)
|
21 |
+
self.assertEqual(embeddings.weight.shape[0], config.vocab_size)
|
22 |
+
|
23 |
+
def test_roberta(self):
|
24 |
+
model_name = 'roberta-base'
|
25 |
+
config, model, tokenizer = _load(model_name)
|
26 |
+
embeddings = ct.get_embeddings(model, config)
|
27 |
+
self.assertEqual(embeddings.weight.shape[0], config.vocab_size)
|
28 |
+
|
29 |
+
|
30 |
+
class TestGradientStorage(TestCase):
|
31 |
+
def test_gradient_storage(self):
|
32 |
+
num_embeddings = 3
|
33 |
+
embedding_dim = 4
|
34 |
+
embeddings = torch.nn.Embedding(num_embeddings, embedding_dim)
|
35 |
+
embedding_storage = ct.GradientStorage(embeddings)
|
36 |
+
|
37 |
+
inputs = torch.tensor([0, 1, 2, 1])
|
38 |
+
outputs = embeddings(inputs)
|
39 |
+
outputs.retain_grad()
|
40 |
+
loss = outputs.sum()
|
41 |
+
loss.backward()
|
42 |
+
|
43 |
+
assert torch.equal(outputs.grad, embedding_storage.get())
|
44 |
+
|
45 |
+
|
46 |
+
def test_replace_trigger_tokens():
|
47 |
+
model_inputs = {
|
48 |
+
'input_ids': torch.tensor([
|
49 |
+
[1, 2, 3, 4],
|
50 |
+
[1, 1, 1, 0]
|
51 |
+
])
|
52 |
+
}
|
53 |
+
trigger_ids = torch.tensor([[5, 6]])
|
54 |
+
trigger_mask = torch.tensor([
|
55 |
+
[True, True, False, False],
|
56 |
+
[False, True, False, True]
|
57 |
+
])
|
58 |
+
replaced = ct.replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask)
|
59 |
+
expected = torch.tensor([
|
60 |
+
[5, 6, 3, 4],
|
61 |
+
[1, 5, 1, 6]
|
62 |
+
])
|
63 |
+
assert torch.equal(expected, replaced['input_ids'])
|
tests/test_utils.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unittest import TestCase
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from transformers import AutoConfig, AutoTokenizer
|
6 |
+
|
7 |
+
import autoprompt.utils as utils
|
8 |
+
|
9 |
+
|
10 |
+
class TestEncodeLabel(TestCase):
|
11 |
+
def setUp(self):
|
12 |
+
self._tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
|
13 |
+
|
14 |
+
def test_single_token(self):
|
15 |
+
output = utils.encode_label(self._tokenizer, 'the')
|
16 |
+
expected_output = torch.tensor([self._tokenizer.convert_tokens_to_ids(['the'])])
|
17 |
+
assert torch.equal(output, expected_output)
|
18 |
+
|
19 |
+
def test_multiple_tokens(self):
|
20 |
+
output = utils.encode_label(self._tokenizer, ['a', 'the'])
|
21 |
+
expected_output = torch.tensor([
|
22 |
+
self._tokenizer.convert_tokens_to_ids(['a', 'the'])
|
23 |
+
])
|
24 |
+
assert torch.equal(output, expected_output)
|
25 |
+
|
26 |
+
|
27 |
+
class TestTriggerTemplatizer(TestCase):
|
28 |
+
def setUp(self):
|
29 |
+
self.default_template = '[T] [T] {arbitrary} [T] {fields} [P]'
|
30 |
+
self.default_config = AutoConfig.from_pretrained('bert-base-cased')
|
31 |
+
self.default_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
|
32 |
+
utils.add_task_specific_tokens(self.default_tokenizer)
|
33 |
+
self.default_instance = {
|
34 |
+
'arbitrary': 'does this',
|
35 |
+
'fields': 'work',
|
36 |
+
'label': 'and'
|
37 |
+
}
|
38 |
+
|
39 |
+
def test_bert(self):
|
40 |
+
templatizer = utils.TriggerTemplatizer(
|
41 |
+
self.default_template,
|
42 |
+
self.default_config,
|
43 |
+
self.default_tokenizer,
|
44 |
+
add_special_tokens=False
|
45 |
+
)
|
46 |
+
model_inputs, label = templatizer(self.default_instance)
|
47 |
+
|
48 |
+
# Label should be mapped to its token id
|
49 |
+
expected_label = torch.tensor([self.default_tokenizer.convert_tokens_to_ids([self.default_instance['label']])])
|
50 |
+
assert torch.equal(expected_label, label)
|
51 |
+
|
52 |
+
# For BERT ouput is expected to have the following keys
|
53 |
+
assert 'input_ids' in model_inputs
|
54 |
+
assert 'token_type_ids' in model_inputs
|
55 |
+
assert 'attention_mask' in model_inputs
|
56 |
+
|
57 |
+
# Test that the custom masks match our expectations
|
58 |
+
expected_trigger_mask = torch.tensor(
|
59 |
+
[[True, True, False, False, True, False, False]]
|
60 |
+
)
|
61 |
+
assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])
|
62 |
+
|
63 |
+
expected_predict_mask = torch.tensor(
|
64 |
+
[[False, False, False, False, False, False, True]]
|
65 |
+
)
|
66 |
+
assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
|
67 |
+
|
68 |
+
# Lastly, ensure [P] is replaced by a [MASK] token
|
69 |
+
input_ids = model_inputs['input_ids']
|
70 |
+
predict_mask = model_inputs['predict_mask']
|
71 |
+
predict_token_id = input_ids[predict_mask].squeeze().item()
|
72 |
+
assert predict_token_id == self.default_tokenizer.mask_token_id
|
73 |
+
|
74 |
+
def test_roberta(self):
|
75 |
+
config = AutoConfig.from_pretrained('roberta-base')
|
76 |
+
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
|
77 |
+
utils.add_task_specific_tokens(tokenizer)
|
78 |
+
templatizer = utils.TriggerTemplatizer(
|
79 |
+
self.default_template,
|
80 |
+
config,
|
81 |
+
tokenizer,
|
82 |
+
add_special_tokens=False
|
83 |
+
)
|
84 |
+
|
85 |
+
model_inputs, label = templatizer(self.default_instance)
|
86 |
+
|
87 |
+
# Label should be mapped to its token id
|
88 |
+
expected_label = torch.tensor([tokenizer.convert_tokens_to_ids([self.default_instance['label']])])
|
89 |
+
assert torch.equal(expected_label, label)
|
90 |
+
|
91 |
+
# For BERT ouput is expected to have the following keys
|
92 |
+
print(model_inputs)
|
93 |
+
assert 'input_ids' in model_inputs
|
94 |
+
assert 'attention_mask' in model_inputs
|
95 |
+
|
96 |
+
# Test that the custom masks match our expectations
|
97 |
+
expected_trigger_mask = torch.tensor(
|
98 |
+
[[True, True, False, False, True, False, False]]
|
99 |
+
)
|
100 |
+
assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])
|
101 |
+
|
102 |
+
expected_predict_mask = torch.tensor(
|
103 |
+
[[False, False, False, False, False, False, True]]
|
104 |
+
)
|
105 |
+
assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
|
106 |
+
|
107 |
+
# Lastly, ensure [P] is replaced by a [MASK] token
|
108 |
+
input_ids = model_inputs['input_ids']
|
109 |
+
predict_mask = model_inputs['predict_mask']
|
110 |
+
predict_token_id = input_ids[predict_mask].squeeze().item()
|
111 |
+
assert predict_token_id == tokenizer.mask_token_id
|
112 |
+
|
113 |
+
|
114 |
+
class TestCollator(TestCase):
|
115 |
+
|
116 |
+
def test_collator(self):
|
117 |
+
template = '[T] [T] {arbitrary} [T] {fields} [P]'
|
118 |
+
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
|
119 |
+
config = AutoConfig.from_pretrained('bert-base-cased')
|
120 |
+
utils.add_task_specific_tokens(tokenizer)
|
121 |
+
templatizer = utils.TriggerTemplatizer(
|
122 |
+
template,
|
123 |
+
config,
|
124 |
+
tokenizer,
|
125 |
+
add_special_tokens=False
|
126 |
+
)
|
127 |
+
collator = utils.Collator(pad_token_id=tokenizer.pad_token_id)
|
128 |
+
|
129 |
+
instances = [
|
130 |
+
{'arbitrary': 'a', 'fields': 'the', 'label': 'hot'},
|
131 |
+
{'arbitrary': 'a a', 'fields': 'the the', 'label': 'cold'}
|
132 |
+
]
|
133 |
+
templatized_instances = [templatizer(x) for x in instances]
|
134 |
+
loader = DataLoader(
|
135 |
+
templatized_instances,
|
136 |
+
batch_size=2,
|
137 |
+
shuffle=False,
|
138 |
+
collate_fn=collator
|
139 |
+
)
|
140 |
+
model_inputs, labels = next(iter(loader))
|
141 |
+
|
142 |
+
# Check results match our expectations
|
143 |
+
expected_labels = torch.tensor([
|
144 |
+
tokenizer.encode('hot', add_special_tokens=False, add_prefix_space=True),
|
145 |
+
tokenizer.encode('cold', add_special_tokens=False, add_prefix_space=True),
|
146 |
+
])
|
147 |
+
assert torch.equal(expected_labels, labels)
|
148 |
+
|
149 |
+
expected_trigger_mask = torch.tensor([
|
150 |
+
[True, True, False, True, False, False, False, False],
|
151 |
+
[True, True, False, False, True, False, False, False],
|
152 |
+
])
|
153 |
+
assert torch.equal(expected_trigger_mask, model_inputs['trigger_mask'])
|
154 |
+
|
155 |
+
expected_predict_mask = torch.tensor([
|
156 |
+
[False, False, False, False, False, True, False, False],
|
157 |
+
[False, False, False, False, False, False, False, True],
|
158 |
+
])
|
159 |
+
assert torch.equal(expected_predict_mask, model_inputs['predict_mask'])
|