yifan0sun committed
Commit 93adbb1 · verified · 1 Parent(s): 04ccab0

Update ROBERTAmodel.py

Files changed (1)
  1. ROBERTAmodel.py +224 -207
ROBERTAmodel.py CHANGED
@@ -1,207 +1,224 @@
-from transformers import RobertaTokenizer, RobertaForMaskedLM
-import torch
-import torch.nn.functional as F
-from models import TransformerVisualizer
-from transformers import (
-    RobertaForMaskedLM, RobertaForSequenceClassification
-)
-import os
-
-CACHE_DIR = "/data/hf_cache"
-
-class RoBERTaVisualizer(TransformerVisualizer):
-    def __init__(self, task):
-        super().__init__()
-        self.task = task
-
-
-
-        TOKENIZER = 'roberta-base'
-        LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER)
-
-        self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-        """
-        try:
-            self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-        except Exception as e:
-            self.tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER)
-            self.tokenizer.save_pretrained(LOCAL_PATH)
-        """
-        if self.task == 'mlm':
-
-            MODEL = "roberta-base"
-            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-            self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
-            """
-            try:
-                self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
-            except Exception as e:
-                self.model = RobertaForMaskedLM.from_pretrained( MODEL )
-                self.model.save_pretrained(LOCAL_PATH)
-            """
-        elif self.task == 'sst':
-
-
-            MODEL = 'textattack_roberta-base-SST-2'
-            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-            self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
-            """
-            try:
-                self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
-            except Exception as e:
-                self.model = RobertaForSequenceClassification.from_pretrained( MODEL )
-                self.model.save_pretrained(LOCAL_PATH)
-            """
-
-        elif self.task == 'mnli':
-            MODEL = "roberta-large-mnli"
-            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-
-            self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
-            """
-            try:
-                self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
-            except Exception as e:
-                self.model = RobertaForSequenceClassification.from_pretrained( MODEL)
-                self.model.save_pretrained(LOCAL_PATH)
-            """
-
-
-
-        self.model.to(self.device)
-        self.model.eval()
-        self.num_attention_layers = self.model.config.num_hidden_layers
-
-
-    def tokenize(self, text, hypothesis = ''):
-
-
-
-        if len(hypothesis) == 0:
-            encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-        else:
-            encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-
-        input_ids = encoded['input_ids'].to(self.device)
-        attention_mask = encoded['attention_mask'].to(self.device)
-        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-        print('First time tokenizing:', tokens, len(tokens))
-
-        response = {
-            'input_ids': input_ids,
-            'attention_mask': attention_mask,
-            'tokens': tokens
-        }
-        print(response)
-        return response
-
-    def predict(self, task, text, hypothesis='', maskID = None):
-
-
-
-        if task == 'mlm':
-            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
-            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-                mask_index = maskID
-            else:
-                raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-                logits = outputs.logits
-
-            mask_logits = logits[0, mask_index]
-            top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
-            decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
-            return decoded, top_probs
-
-        elif task == 'sst':
-            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-                logits = outputs.logits
-                probs = F.softmax(logits, dim=1).squeeze()
-
-            labels = ["negative", "positive"]
-            return labels, probs
-
-        elif task == 'mnli':
-            inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-                logits = outputs.logits
-                probs = F.softmax(logits, dim=1).squeeze()
-
-            labels = ["entailment", "neutral", "contradiction"]
-            return labels, probs
-
-        else:
-            raise NotImplementedError(f"Task '{task}' not supported for RoBERTa")
-
-
-    def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = None):
-        print(task, sentence, hypothesis)
-        print('Tokenize')
-        if task == 'mnli':
-            inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
-        elif task == 'mlm':
-            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-        else:
-            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-        print(tokens)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-        print('Input embeddings with grad')
-        embedding_layer = self.model.roberta.embeddings.word_embeddings
-        inputs_embeds = embedding_layer(inputs["input_ids"])
-        inputs_embeds.requires_grad_()
-
-        print('Forward pass')
-        outputs = self.model.roberta(
-            inputs_embeds=inputs_embeds,
-            attention_mask=inputs["attention_mask"],
-            output_attentions=True
-        )
-        attentions = outputs.attentions  # list of [1, heads, seq, seq]
-
-        print('Average attentions per layer')
-        mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
-
-        attn_matrices_all = []
-        grad_matrices_all = []
-        for target_layer in range(len(attentions)):
-            grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
-            grad_matrices_all.append(grad_matrix.tolist())
-            attn_matrices_all.append(attn_matrix.tolist())
-        return grad_matrices_all, attn_matrices_all
-
-    def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-
-        attn_matrix = mean_attns[target_layer]
-        seq_len = attn_matrix.shape[0]
-        attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
-
-        print('Computing grad norms')
-        grad_norms_list = []
-        for k in range(seq_len):
-            scalar = attn_layer[:, k].sum()
-            grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
-            grad_norms = grad.norm(dim=1)
-            grad_norms_list.append(grad_norms.unsqueeze(1))
-
-        grad_matrix = torch.cat(grad_norms_list, dim=1)
-        grad_matrix = grad_matrix[:seq_len, :seq_len]
-        attn_matrix = attn_matrix[:seq_len, :seq_len]
-
-
-        return grad_matrix, attn_matrix
+from transformers import RobertaTokenizer, RobertaForMaskedLM
+import torch
+import torch.nn.functional as F
+from models import TransformerVisualizer
+from transformers import (
+    RobertaForMaskedLM, RobertaForSequenceClassification
+)
+import os
+
+CACHE_DIR = "/data/hf_cache"
+
+class RoBERTaVisualizer(TransformerVisualizer):
+    def __init__(self, task):
+        super().__init__()
+        self.task = task
+
+
+
+        TOKENIZER = 'roberta-base'
+        LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER)
+
+        self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+        """
+        try:
+            self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+        except Exception as e:
+            self.tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER)
+            self.tokenizer.save_pretrained(LOCAL_PATH)
+        """
+        if self.task == 'mlm':
+
+            MODEL = "roberta-base"
+            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+            self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
+            """
+            try:
+                self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
+            except Exception as e:
+                self.model = RobertaForMaskedLM.from_pretrained( MODEL )
+                self.model.save_pretrained(LOCAL_PATH)
+            """
+        elif self.task == 'sst':
+
+
+            MODEL = 'textattack_roberta-base-SST-2'
+            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+            self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+            """
+            try:
+                self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+            except Exception as e:
+                self.model = RobertaForSequenceClassification.from_pretrained( MODEL )
+                self.model.save_pretrained(LOCAL_PATH)
+            """
+
+        elif self.task == 'mnli':
+            MODEL = "roberta-large-mnli"
+            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+
+            self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+            """
+            try:
+                self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+            except Exception as e:
+                self.model = RobertaForSequenceClassification.from_pretrained( MODEL)
+                self.model.save_pretrained(LOCAL_PATH)
+            """
+
+
+
+        self.model.to(self.device)
+        self.model.eval()
+        self.num_attention_layers = self.model.config.num_hidden_layers
+
+
+    def tokenize(self, text, hypothesis = ''):
+
+
+
+        if len(hypothesis) == 0:
+            encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+        else:
+            encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+
+        input_ids = encoded['input_ids'].to(self.device)
+        attention_mask = encoded['attention_mask'].to(self.device)
+        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+        print('First time tokenizing:', tokens, len(tokens))
+
+        response = {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'tokens': tokens
+        }
+        print(response)
+        return response
+
+    def predict(self, task, text, hypothesis='', maskID = None):
+
+
+
+        if task == 'mlm':
+            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                mask_index = maskID
+            else:
+                raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+
+            mask_logits = logits[0, mask_index]
+            top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
+            decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+            return decoded, top_probs
+
+        elif task == 'sst':
+            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+                probs = F.softmax(logits, dim=1).squeeze()
+
+            labels = ["negative", "positive"]
+            return labels, probs
+
+        elif task == 'mnli':
+            inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+                probs = F.softmax(logits, dim=1).squeeze()
+
+            labels = ["entailment", "neutral", "contradiction"]
+            return labels, probs
+
+        else:
+            raise NotImplementedError(f"Task '{task}' not supported for RoBERTa")
+
+
+    def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = None):
+        print(task, sentence, hypothesis)
+        print('Tokenize')
+        if task == 'mnli':
+            inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+        elif task == 'mlm':
+            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+        else:
+            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+        print(tokens)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        print('Input embeddings with grad')
+        embedding_layer = self.model.roberta.embeddings.word_embeddings
+        inputs_embeds = embedding_layer(inputs["input_ids"])
+        inputs_embeds.requires_grad_()
+
+        print('Forward pass')
+        outputs = self.model.roberta(
+            inputs_embeds=inputs_embeds,
+            attention_mask=inputs["attention_mask"],
+            output_attentions=True
+        )
+        attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+        print('Average attentions per layer')
+        mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+        attn_matrices_all = []
+        grad_matrices_all = []
+        for target_layer in range(len(attentions)):
+            grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+            grad_matrices_all.append(grad_matrix.tolist())
+            attn_matrices_all.append(attn_matrix.tolist())
+        return grad_matrices_all, attn_matrices_all
+
+    def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
+
+        attn_matrix = mean_attns[target_layer]
+        seq_len = attn_matrix.shape[0]
+
+        attn_matrix = torch.round(attn_matrix.float() * 100) / 100
+        attn_matrix = attn_matrix.to(torch.float16)
+
+        attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
+
+        print('Computing grad norms')
+        grad_norms_list = []
+        for k in range(seq_len):
+            scalar = attn_layer[:, k].sum()
+            grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
+            grad_norms = grad.norm(dim=1)
+
+
+            grad_norms = torch.round(grad_norms.unsqueeze(1).float() * 100) / 100
+            grad_norms = grad_norms.to(torch.float16)
+
+            grad_norms_list.append(grad_norms)
+
+        grad_matrix = torch.cat(grad_norms_list, dim=1)
+        grad_matrix = grad_matrix[:seq_len, :seq_len]
+        attn_matrix = attn_matrix[:seq_len, :seq_len]
+
+        attn_matrix = torch.round(attn_matrix.float() * 100) / 100
+        attn_matrix = attn_matrix.to(torch.float16)
+
+        grad_matrix = torch.round(grad_matrix.float() * 100) / 100
+        grad_matrix = grad_matrix.to(torch.float16)
+
+
+
+
+        return grad_matrix, attn_matrix
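
Note (not part of the diff): the functional change in this commit is that get_grad_attn_matrix now rounds the averaged attention matrix and the per-token gradient norms to two decimal places and casts them to torch.float16 before they are sliced, returned, and converted with .tolist() in get_all_grad_attn_matrix. The sketch below isolates that post-processing step under that reading of the diff; the helper name quantize_for_payload is hypothetical and only illustrates the round-then-half-precision pattern.

import torch

def quantize_for_payload(matrix: torch.Tensor) -> torch.Tensor:
    # Round to two decimal places, mirroring the new code:
    # torch.round(x.float() * 100) / 100
    rounded = torch.round(matrix.float() * 100) / 100
    # Cast to half precision before .tolist(); float16 keeps the resulting
    # lists small but only approximates the rounded values (~3 significant
    # decimal digits), so entries are not exact two-decimal numbers.
    return rounded.to(torch.float16)

# Example with a stand-in for one averaged [seq, seq] attention matrix.
attn = torch.rand(4, 4)
print(quantize_for_payload(attn))

Rounding before the float16 cast trades a little numerical fidelity for smaller values once the matrices are turned into Python lists, presumably to keep the visualizer's serialized output compact.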