yifan0sun committed on
Commit 6aed7ad · verified · 1 Parent(s): 19c6bd3

Update DISTILLBERTmodel.py

Files changed (1)
  1. DISTILLBERTmodel.py +270 -257
DISTILLBERTmodel.py CHANGED
@@ -1,258 +1,271 @@
-import torch
-import torch.nn.functional as F
-
-
-
-import os
-from models import TransformerVisualizer
-
-from transformers import (
-    DistilBertTokenizer,
-    DistilBertForMaskedLM, DistilBertForSequenceClassification
-)
-
-CACHE_DIR = "/data/hf_cache"
-class DistilBERTVisualizer(TransformerVisualizer):
-    def __init__(self, task):
-        super().__init__()
-        self.task = task
-
-
-        TOKENIZER = 'distilbert-base-uncased'
-        LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
-
-        self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-        """
-        try:
-            self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-        except Exception as e:
-            self.tokenizer = DistilBertTokenizer.from_pretrained(TOKENIZER)
-            self.tokenizer.save_pretrained(LOCAL_PATH)
-        """
-
-
-        print('finding model', self.task)
-        if self.task == 'mlm':
-
-            MODEL = 'distilbert-base-uncased'
-            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-            self.model = DistilBertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
-            """
-            try:
-            except Exception as e:
-                self.model = DistilBertForMaskedLM.from_pretrained( MODEL )
-                self.model.save_pretrained(LOCAL_PATH)
-            """
-        elif self.task == 'sst':
-            MODEL = 'distilbert-base-uncased-finetuned-sst-2-english'
-            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-            self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
-            """
-            try:
-                self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
-            except Exception as e:
-                self.model = DistilBertForSequenceClassification.from_pretrained( MODEL )
-                self.model.save_pretrained(LOCAL_PATH)
-            """
-
-        elif self.task == 'mnli':
-            MODEL = "textattack_distilbert-base-uncased-MNLI"
-            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-
-            self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
-            """
-            try:
-                self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
-            except Exception as e:
-                self.model = DistilBertForSequenceClassification.from_pretrained( MODEL)
-                self.model.save_pretrained(LOCAL_PATH)
-            """
-
-
-
-        else:
-            raise ValueError(f"Unsupported task: {self.task}")
-
-
-
-
-
-
-        self.model.eval()
-        self.num_attention_layers = len(self.model.distilbert.transformer.layer)
-
-        self.model.to(self.device)
-
-
-
-    def tokenize(self, text, hypothesis = ''):
-
-
-
-        if len(hypothesis) == 0:
-            encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-        else:
-            encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-
-
-        input_ids = encoded['input_ids'].to(self.device)
-        attention_mask = encoded['attention_mask'].to(self.device)
-        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-        return {
-            'input_ids': input_ids,
-            'attention_mask': attention_mask,
-            'tokens': tokens
-        }
-
-    def predict(self, task, text, hypothesis='', maskID = 0):
-
-        if task == 'mlm':
-            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
-            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-                mask_index = maskID
-            else:
-                raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-                logits = outputs.logits
-
-            mask_logits = logits[0, mask_index]
-            top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
-            decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
-            return decoded, top_probs
-
-        elif task == 'sst':
-            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-                logits = outputs.logits
-                probs = F.softmax(logits, dim=1).squeeze()
-
-            labels = ["negative", "positive"]
-            return labels, probs
-        elif task == 'mnli':
-            inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-                logits = outputs.logits
-                probs = F.softmax(logits, dim=1).squeeze()
-
-            labels = ["entailment", "neutral", "contradiction"]
-            return labels, probs
-
-        else:
-            raise NotImplementedError(f"Task '{task}' not supported for DistilBERT")
-
-    def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
-        print(task, sentence,hypothesis)
-
-        print('Tokenize')
-        if task == 'mnli':
-            inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
-        elif task == 'mlm':
-            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-            else:
-                print(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-                raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-        else:
-            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-        print(tokens)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-        print('Input embeddings with grad')
-        embedding_layer = self.model.distilbert.embeddings.word_embeddings
-        inputs_embeds = embedding_layer(inputs["input_ids"])
-        inputs_embeds.requires_grad_()
-
-        print('Forward pass')
-        outputs = self.model.distilbert(
-            inputs_embeds=inputs_embeds,
-            attention_mask=inputs["attention_mask"],
-            output_attentions=True,
-        )
-        attentions = outputs.attentions  # list of [1, heads, seq, seq]
-
-        print('Mean attentions per layer')
-        mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
-
-
-
-        attn_matrices_all = []
-        grad_matrices_all = []
-        for target_layer in range(len(attentions)):
-            grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
-            grad_matrices_all.append(grad_matrix.tolist())
-            attn_matrices_all.append(attn_matrix.tolist())
-        return grad_matrices_all, attn_matrices_all
-
-
-    def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-        attn_matrix = mean_attns[target_layer]
-        seq_len = attn_matrix.shape[0]
-        attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)
-
-        print('Computing grad norms')
-        grad_norms_list = []
-        for k in range(seq_len):
-            scalar = attn_layer[:, k].sum()
-            grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
-            grad_norms = grad.norm(dim=1)
-            grad_norms_list.append(grad_norms.unsqueeze(1))
-
-        grad_matrix = torch.cat(grad_norms_list, dim=1)
-        grad_matrix = grad_matrix[:seq_len, :seq_len]
-        attn_matrix = attn_matrix[:seq_len, :seq_len]
-
-        return grad_matrix, attn_matrix
-
-
-
-if __name__ == "__main__":
-    import sys
-
-    MODEL_CLASSES = {
-        "bert": BERTVisualizer,
-        "roberta": RoBERTaVisualizer,
-        "distilbert": DistilBERTVisualizer,
-        "bart": BARTVisualizer,
-    }
-
-    # Parse command-line args or fallback to default
-    model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
-    text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
-
-    if model_name.lower() not in MODEL_CLASSES:
-        print(f"Supported models: {list(MODEL_CLASSES.keys())}")
-        sys.exit(1)
-
-    # Instantiate the visualizer
-    visualizer_class = MODEL_CLASSES[model_name.lower()]
-    visualizer = visualizer_class()
-
-    # Tokenize
-    token_info = visualizer.tokenize(text)
-
-    # Report
-    print(f"\nModel: {model_name}")
-    print(f"Num attention layers: {visualizer.num_attention_layers}")
-    print(f"Tokens: {token_info['tokens']}")
-    print(f"Input IDs: {token_info['input_ids'].tolist()}")
-    print(f"Attention mask: {token_info['attention_mask'].tolist()}")
-
-
-"""
-usage for debug:
-python your_file.py bert "The rain in Spain falls mainly on the plain."
+import torch
+import torch.nn.functional as F
+
+
+
+import os, time
+from models import TransformerVisualizer
+
+from transformers import (
+    DistilBertTokenizer,
+    DistilBertForMaskedLM, DistilBertForSequenceClassification
+)
+
+CACHE_DIR = "/data/hf_cache"
+class DistilBERTVisualizer(TransformerVisualizer):
+    def __init__(self, task):
+        super().__init__()
+        self.task = task
+
+
+        TOKENIZER = 'distilbert-base-uncased'
+        LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
+
+        self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+        """
+        try:
+            self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+        except Exception as e:
+            self.tokenizer = DistilBertTokenizer.from_pretrained(TOKENIZER)
+            self.tokenizer.save_pretrained(LOCAL_PATH)
+        """
+
+
+        print('finding model', self.task)
+        if self.task == 'mlm':
+
+            MODEL = 'distilbert-base-uncased'
+            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+            self.model = DistilBertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
+            """
+            try:
+            except Exception as e:
+                self.model = DistilBertForMaskedLM.from_pretrained( MODEL )
+                self.model.save_pretrained(LOCAL_PATH)
+            """
+        elif self.task == 'sst':
+            MODEL = 'distilbert-base-uncased-finetuned-sst-2-english'
+            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+            self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+            """
+            try:
+                self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+            except Exception as e:
+                self.model = DistilBertForSequenceClassification.from_pretrained( MODEL )
+                self.model.save_pretrained(LOCAL_PATH)
+            """
+
+        elif self.task == 'mnli':
+            MODEL = "textattack_distilbert-base-uncased-MNLI"
+            LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+
+            self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+            """
+            try:
+                self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+            except Exception as e:
+                self.model = DistilBertForSequenceClassification.from_pretrained( MODEL)
+                self.model.save_pretrained(LOCAL_PATH)
+            """
+
+
+
+        else:
+            raise ValueError(f"Unsupported task: {self.task}")
+
+
+
+
+
+
+        self.model.to(self.device)
+        # Force materialization of all layers (avoids meta device errors)
+        with torch.no_grad():
+            dummy_ids = torch.tensor([[0, 1]], device=self.device)
+            dummy_mask = torch.tensor([[1, 1]], device=self.device)
+            _ = self.model(input_ids=dummy_ids, attention_mask=dummy_mask)
+        self.model.eval()
+        self.num_attention_layers = len(self.model.distilbert.transformer.layer)
+
+
+
+
+    def tokenize(self, text, hypothesis = ''):
+
+
+
+        if len(hypothesis) == 0:
+            encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+        else:
+            encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+
+
+        input_ids = encoded['input_ids'].to(self.device)
+        attention_mask = encoded['attention_mask'].to(self.device)
+        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+        return {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'tokens': tokens
+        }
+
+    def predict(self, task, text, hypothesis='', maskID = 0):
+
+        if task == 'mlm':
+            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                mask_index = maskID
+            else:
+                raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+
+            mask_logits = logits[0, mask_index]
+            top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
+            decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+            return decoded, top_probs
+
+        elif task == 'sst':
+            inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+                probs = F.softmax(logits, dim=1).squeeze()
+
+            labels = ["negative", "positive"]
+            return labels, probs
+        elif task == 'mnli':
+            inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits
+                probs = F.softmax(logits, dim=1).squeeze()
+
+            labels = ["entailment", "neutral", "contradiction"]
+            return labels, probs
+
+        else:
+            raise NotImplementedError(f"Task '{task}' not supported for DistilBERT")
+
+    def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
+        print(task, sentence,hypothesis)
+
+        print('Tokenize')
+        if task == 'mnli':
+            inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+        elif task == 'mlm':
+            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+            if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+            else:
+                print(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+                raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        else:
+            inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+        print(tokens)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        print('Input embeddings with grad')
+        embedding_layer = self.model.distilbert.embeddings.word_embeddings
+        inputs_embeds = embedding_layer(inputs["input_ids"]).to(self.device)
+        inputs_embeds.requires_grad_()
+
+        print('Forward pass')
+        outputs = self.model.distilbert(
+            inputs_embeds=inputs_embeds,
+            attention_mask=inputs["attention_mask"],
+            output_attentions=True,
+        )
+
+        attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+        print('Average attentions per layer')
+        mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+
+        def scalar_outputs(inputs_embeds):
+
+            outputs = self.model.distilbert(
+                inputs_embeds=inputs_embeds,
+                attention_mask=inputs["attention_mask"],
+                output_attentions=True
+            )
+            attentions = outputs.attentions
+            attentions_condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
+            attentions_condensed= torch.vstack(attentions_condensed)
+            return attentions_condensed
+
+        start = time.time()
+        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds).to(torch.float16)
+        print('time to get jacobian: ', time.time()-start)
+        jac = jac.norm(dim=-1).squeeze(dim=2)
+        seq_len = jac.shape[0]
+        grad_matrices_all = [jac[ii,:,:].tolist() for ii in range(seq_len)]
+
+
+        attn_matrices_all = []
+        for target_layer in range(len(attentions)):
+            #grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+
+            attn_matrix = mean_attns[target_layer]
+            seq_len = attn_matrix.shape[0]
+            attn_matrix = attn_matrix[:seq_len, :seq_len]
+            attn_matrices_all.append(attn_matrix.tolist())
+
+
+
+        return grad_matrices_all, attn_matrices_all
+
+
+
+
+
+
+if __name__ == "__main__":
+    import sys
+
+    MODEL_CLASSES = {
+        "bert": BERTVisualizer,
+        "roberta": RoBERTaVisualizer,
+        "distilbert": DistilBERTVisualizer,
+        "bart": BARTVisualizer,
+    }
+
+    # Parse command-line args or fallback to default
+    model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
+    text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
+
+    if model_name.lower() not in MODEL_CLASSES:
+        print(f"Supported models: {list(MODEL_CLASSES.keys())}")
+        sys.exit(1)
+
+    # Instantiate the visualizer
+    visualizer_class = MODEL_CLASSES[model_name.lower()]
+    visualizer = visualizer_class()
+
+    # Tokenize
+    token_info = visualizer.tokenize(text)
+
+    # Report
+    print(f"\nModel: {model_name}")
+    print(f"Num attention layers: {visualizer.num_attention_layers}")
+    print(f"Tokens: {token_info['tokens']}")
+    print(f"Input IDs: {token_info['input_ids'].tolist()}")
+    print(f"Attention mask: {token_info['attention_mask'].tolist()}")
+
+
+"""
+usage for debug:
+python your_file.py bert "The rain in Spain falls mainly on the plain."
  """