yifan0sun committed
Commit 993b547 · 1 Parent(s): 6173907
Files changed (6)
  1. BERTmodel.py +257 -0
  2. DISTILLBERTmodel.py +210 -0
  3. ROBERTAmodel.py +155 -0
  4. models.py +34 -0
  5. requirements.txt +4 -0
  6. server.py +215 -0
BERTmodel.py ADDED
@@ -0,0 +1,257 @@
+ import torch
+ import torch.nn.functional as F
+
+ from transformers import (
+     BertTokenizer,
+     BertForMaskedLM,
+     BertForSequenceClassification,
+ )
+
+ from models import TransformerVisualizer
+
+
+ class BERTVisualizer(TransformerVisualizer):
+     def __init__(self, task):
+         super().__init__()
+         self.task = task
+         self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+         print('finding model', self.task)
+         if self.task == 'mlm':
+             self.model = BertForMaskedLM.from_pretrained(
+                 "bert-base-uncased",
+                 attn_implementation="eager"  # fallback to standard attention
+             ).to(self.device)
+         elif self.task == 'sst':
+             self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2", device_map=None)
+         elif self.task == 'mnli':
+             self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-MNLI", device_map=None)
+         else:
+             raise ValueError(f"Unsupported task: {self.task}")
+         print('model found')
+         # self.model.to(self.device)
+         print('self device junk')
+         self.model.eval()
+         print('self model eval')
+         self.num_attention_layers = len(self.model.bert.encoder.layer)
+         print('init finished')
+
+     def tokenize(self, text, hypothesis=''):
+         print('TTTokenize', text, 'H:', hypothesis)
+         if len(hypothesis) == 0:
+             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True)
+         else:
+             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True)
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'tokens': tokens
+         }
+
+     def predict(self, task, text, hypothesis='', maskID=None):
+         print(task, text, hypothesis)
+
+         if task == 'mlm':
+             # Tokenize and find [MASK] position
+             print('Tokenize and find [MASK] position')
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                 mask_index = maskID
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
+
+             # Move to device
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             # Get embeddings
+             embedding_layer = self.model.bert.embeddings.word_embeddings
+             inputs_embeds = embedding_layer(inputs['input_ids'])
+
+             # Forward through BERT encoder
+             hidden_states = self.model.bert(inputs_embeds=inputs_embeds,
+                                             attention_mask=inputs['attention_mask']).last_hidden_state
+
+             # Predict logits via MLM head; softmax over the full vocabulary, then keep the top 10
+             logits = self.model.cls(hidden_states)
+             mask_logits = logits[0, mask_index]
+             probs = F.softmax(mask_logits, dim=-1)
+             top_probs, top_indices = torch.topk(probs, k=10, dim=-1)
+             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+
+             return decoded, top_probs
+
+         elif task == 'sst':
+             print('input')
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+             print('output')
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits  # shape: [1, 2]
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["negative", "positive"]
+             print('ready to return')
+             return labels, probs
+
+         elif task == 'mnli':
+             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["entailment", "neutral", "contradiction"]
+             return labels, probs
+
+     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID=0):
+         print('GET GRAD:', task, 'sentence', sentence, 'hypothesis', hypothesis)
+
+         print('Tokenize')
+         if task == 'mnli':
+             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+         elif task == 'mlm':
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
+         else:
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         print(inputs['input_ids'].shape)
+         print(tokens, len(tokens))
+         print('Input embeddings with grad')
+         embedding_layer = self.model.bert.embeddings.word_embeddings
+         inputs_embeds = embedding_layer(inputs["input_ids"])
+         inputs_embeds.requires_grad_()
+
+         print('Forward pass')
+         outputs = self.model.bert(
+             inputs_embeds=inputs_embeds,
+             attention_mask=inputs["attention_mask"],
+             output_attentions=True
+         )
+         attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+         print('Optional: store average attentions per layer')
+         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+         attn_matrices_all = []
+         grad_matrices_all = []
+         for target_layer in range(len(attentions)):
+             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+             grad_matrices_all.append(grad_matrix.tolist())
+             attn_matrices_all.append(attn_matrix.tolist())
+         return grad_matrices_all, attn_matrices_all
+
+     def get_grad_attn_matrix(self, inputs_embeds, attentions, mean_attns, target_layer):
+         attn_matrix = mean_attns[target_layer]
+         seq_len = attn_matrix.shape[0]
+         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
+
+         print('computing gradnorms now')
+         grad_norms_list = []
+         for k in range(seq_len):
+             scalar = attn_layer[:, k].sum()  # ✅ total attention received by token k
+
+             # Compute gradient: d scalar / d inputs_embeds
+             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)  # shape: [seq, hidden]
+             grad_norms = grad.norm(dim=1)  # shape: [seq]
+             grad_norms_list.append(grad_norms.unsqueeze(1))  # shape: [seq, 1]
+
+         grad_matrix = torch.cat(grad_norms_list, dim=1)  # shape: [seq, seq]
+         print('ready to send!')
+
+         grad_matrix = grad_matrix[:seq_len, :seq_len]
+         attn_matrix = attn_matrix[:seq_len, :seq_len]
+
+         # tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+
+         return grad_matrix, attn_matrix
+
+
+ if __name__ == "__main__":
+     import sys
+
+     # Only BERTVisualizer is defined in this module.
+     MODEL_CLASSES = {
+         "bert": BERTVisualizer,
+     }
+
+     # Parse command-line args or fall back to defaults
+     model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
+     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
+
+     if model_name.lower() not in MODEL_CLASSES:
+         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
+         sys.exit(1)
+
+     # Instantiate the visualizer (default to the MLM task for this standalone check)
+     visualizer_class = MODEL_CLASSES[model_name.lower()]
+     visualizer = visualizer_class(task="mlm")
+
+     # Tokenize
+     token_info = visualizer.tokenize(text)
+
+     # Report
+     print(f"\nModel: {model_name}")
+     print(f"Num attention layers: {visualizer.num_attention_layers}")
+     print(f"Tokens: {token_info['tokens']}")
+     print(f"Input IDs: {token_info['input_ids'].tolist()}")
+     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
+
+
+ """
+ usage for debug:
+ python your_file.py bert "The rain in Spain falls mainly on the plain."
+ """
DISTILLBERTmodel.py ADDED
@@ -0,0 +1,210 @@
+ import torch
+ import torch.nn.functional as F
+
+ from transformers import (
+     DistilBertTokenizer,
+     DistilBertForMaskedLM,
+     DistilBertForSequenceClassification,
+ )
+
+ from models import TransformerVisualizer
+
+ class DistilBERTVisualizer(TransformerVisualizer):
+     def __init__(self, task):
+         super().__init__()
+         self.task = task
+         self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+         if self.task == 'mlm':
+             self.model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
+         elif self.task == 'sst':
+             self.model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')
+         elif self.task == 'mnli':
+             self.model = DistilBertForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-MNLI")
+         else:
+             raise NotImplementedError("Task not supported for DistilBERT")
+
+         self.model.eval()
+         self.num_attention_layers = len(self.model.distilbert.transformer.layer)
+         self.model.to(self.device)
+
+     def tokenize(self, text, hypothesis=''):
+         if len(hypothesis) == 0:
+             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True, padding=False, truncation=True)
+         else:
+             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True, padding=False, truncation=True)
+
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'tokens': tokens
+         }
+
+     def predict(self, task, text, hypothesis='', maskID=0):
+         if task == 'mlm':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                 mask_index = maskID
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+
+             mask_logits = logits[0, mask_index]
+             top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
+             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+             return decoded, top_probs
+
+         elif task == 'sst':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["negative", "positive"]
+             return labels, probs
+
+         elif task == 'mnli':
+             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["entailment", "neutral", "contradiction"]
+             return labels, probs
+
+         else:
+             raise NotImplementedError(f"Task '{task}' not supported for DistilBERT")
+
+     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID=0):
+         print(task, sentence, hypothesis)
+
+         print('Tokenize')
+         if task == 'mnli':
+             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+         elif task == 'mlm':
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+             else:
+                 print(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         else:
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+         print(tokens)
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         print('Input embeddings with grad')
+         embedding_layer = self.model.distilbert.embeddings.word_embeddings
+         inputs_embeds = embedding_layer(inputs["input_ids"])
+         inputs_embeds.requires_grad_()
+
+         print('Forward pass')
+         outputs = self.model.distilbert(
+             inputs_embeds=inputs_embeds,
+             attention_mask=inputs["attention_mask"],
+             output_attentions=True,
+         )
+         attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+         print('Mean attentions per layer')
+         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+         attn_matrices_all = []
+         grad_matrices_all = []
+         for target_layer in range(len(attentions)):
+             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+             grad_matrices_all.append(grad_matrix.tolist())
+             attn_matrices_all.append(attn_matrix.tolist())
+         return grad_matrices_all, attn_matrices_all
+
+     def get_grad_attn_matrix(self, inputs_embeds, attentions, mean_attns, target_layer):
+         attn_matrix = mean_attns[target_layer]
+         seq_len = attn_matrix.shape[0]
+         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)
+
+         print('Computing grad norms')
+         grad_norms_list = []
+         for k in range(seq_len):
+             scalar = attn_layer[:, k].sum()
+             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
+             grad_norms = grad.norm(dim=1)
+             grad_norms_list.append(grad_norms.unsqueeze(1))
+
+         grad_matrix = torch.cat(grad_norms_list, dim=1)
+         grad_matrix = grad_matrix[:seq_len, :seq_len]
+         attn_matrix = attn_matrix[:seq_len, :seq_len]
+
+         return grad_matrix, attn_matrix
+
+
+ if __name__ == "__main__":
+     import sys
+
+     # Only DistilBERTVisualizer is defined in this module.
+     MODEL_CLASSES = {
+         "distilbert": DistilBERTVisualizer,
+     }
+
+     # Parse command-line args or fall back to defaults
+     model_name = sys.argv[1] if len(sys.argv) > 1 else "distilbert"
+     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
+
+     if model_name.lower() not in MODEL_CLASSES:
+         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
+         sys.exit(1)
+
+     # Instantiate the visualizer (default to the MLM task for this standalone check)
+     visualizer_class = MODEL_CLASSES[model_name.lower()]
+     visualizer = visualizer_class(task="mlm")
+
+     # Tokenize
+     token_info = visualizer.tokenize(text)
+
+     # Report
+     print(f"\nModel: {model_name}")
+     print(f"Num attention layers: {visualizer.num_attention_layers}")
+     print(f"Tokens: {token_info['tokens']}")
+     print(f"Input IDs: {token_info['input_ids'].tolist()}")
+     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
+
+
+ """
+ usage for debug:
+ python your_file.py distilbert "The rain in Spain falls mainly on the plain."
+ """
ROBERTAmodel.py ADDED
@@ -0,0 +1,155 @@
+ import torch
+ import torch.nn.functional as F
+
+ from transformers import (
+     RobertaTokenizer,
+     RobertaForMaskedLM,
+     RobertaForSequenceClassification,
+ )
+
+ from models import TransformerVisualizer
+
+ class RoBERTaVisualizer(TransformerVisualizer):
+     def __init__(self, task):
+         super().__init__()
+         self.task = task
+         self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+         if self.task == 'mlm':
+             self.model = RobertaForMaskedLM.from_pretrained("roberta-base")
+         elif self.task == 'sst':
+             self.model = RobertaForSequenceClassification.from_pretrained('textattack/roberta-base-SST-2')
+         elif self.task == 'mnli':
+             self.model = RobertaForSequenceClassification.from_pretrained("roberta-large-mnli")
+         else:
+             raise NotImplementedError(f"Task '{self.task}' not supported for RoBERTa")
+
+         self.model.to(self.device)
+         self.model.eval()
+         self.num_attention_layers = self.model.config.num_hidden_layers
+
+     def tokenize(self, text, hypothesis=''):
+         if len(hypothesis) == 0:
+             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True, padding=False, truncation=True)
+         else:
+             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True, padding=False, truncation=True)
+
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+         print('First time tokenizing:', tokens, len(tokens))
+
+         response = {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'tokens': tokens
+         }
+         print(response)
+         return response
+
+     def predict(self, task, text, hypothesis='', maskID=None):
+         if task == 'mlm':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                 mask_index = maskID
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+
+             mask_logits = logits[0, mask_index]
+             top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
+             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+             return decoded, top_probs
+
+         elif task == 'sst':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["negative", "positive"]
+             return labels, probs
+
+         elif task == 'mnli':
+             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             # roberta-large-mnli maps class indices 0/1/2 to contradiction/neutral/entailment
+             labels = ["contradiction", "neutral", "entailment"]
+             return labels, probs
+
+         else:
+             raise NotImplementedError(f"Task '{task}' not supported for RoBERTa")
+
+
+     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID=None):
+         print(task, sentence, hypothesis)
+         print('Tokenize')
+         if task == 'mnli':
+             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+         elif task == 'mlm':
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+         else:
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+         print(tokens)
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         print('Input embeddings with grad')
+         embedding_layer = self.model.roberta.embeddings.word_embeddings
+         inputs_embeds = embedding_layer(inputs["input_ids"])
+         inputs_embeds.requires_grad_()
+
+         print('Forward pass')
+         outputs = self.model.roberta(
+             inputs_embeds=inputs_embeds,
+             attention_mask=inputs["attention_mask"],
+             output_attentions=True
+         )
+         attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+         print('Average attentions per layer')
+         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+         attn_matrices_all = []
+         grad_matrices_all = []
+         for target_layer in range(len(attentions)):
+             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+             grad_matrices_all.append(grad_matrix.tolist())
+             attn_matrices_all.append(attn_matrix.tolist())
+         return grad_matrices_all, attn_matrices_all
+
+     def get_grad_attn_matrix(self, inputs_embeds, attentions, mean_attns, target_layer):
+         attn_matrix = mean_attns[target_layer]
+         seq_len = attn_matrix.shape[0]
+         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
+
+         print('Computing grad norms')
+         grad_norms_list = []
+         for k in range(seq_len):
+             scalar = attn_layer[:, k].sum()
+             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
+             grad_norms = grad.norm(dim=1)
+             grad_norms_list.append(grad_norms.unsqueeze(1))
+
+         grad_matrix = torch.cat(grad_norms_list, dim=1)
+         grad_matrix = grad_matrix[:seq_len, :seq_len]
+         attn_matrix = attn_matrix[:seq_len, :seq_len]
+
+         return grad_matrix, attn_matrix
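
A similar sketch for the MNLI path of the class above (illustrative, not part of the commit); it assumes ROBERTAmodel.py is importable and simply prints the three class probabilities returned by predict.

# Hypothetical example exercising RoBERTaVisualizer on an entailment pair.
from ROBERTAmodel import RoBERTaVisualizer

vis = RoBERTaVisualizer(task='mnli')
labels, probs = vis.predict(
    'mnli',
    "A soccer game with multiple males playing.",
    hypothesis="Some men are playing a sport.",
)
for label, p in zip(labels, probs.tolist()):
    print(f"{label}\t{p:.3f}")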
models.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+
+
+ class TransformerVisualizer():
+     def __init__(self):
+         self.device = torch.device('cpu')
+
+     def predict(self, task, text):
+         return task, text, 1
+
+     def get_attention_gradient_matrix(self, task, text, target_layer):
+         return task, text, target_layer, 1
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ fastapi
+ uvicorn
+ transformers
+ torch
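
No versions are pinned. FastAPI and Uvicorn serve the API defined in server.py below; Transformers and Torch back the model classes. One way to launch the app is sketched here (the launcher file name, host, and port are assumptions, not part of the commit):

# run.py -- hypothetical launcher, equivalent to `uvicorn server:app` on the command line.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("server:app", host="0.0.0.0", port=8000)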
server.py ADDED
@@ -0,0 +1,215 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+
+ import torch
+
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from ROBERTAmodel import *
+ from BERTmodel import *
+ from DISTILLBERTmodel import *
+
+ VISUALIZER_CLASSES = {
+     "BERT": BERTVisualizer,
+     "RoBERTa": RoBERTaVisualizer,
+     "DistilBERT": DistilBERTVisualizer,
+ }
+
+ VISUALIZER_CACHE = {}
+ app = FastAPI()
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # or restrict to ["http://localhost:3000"]
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ MODEL_MAP = {
+     "BERT": "bert-base-uncased",
+     "RoBERTa": "roberta-base",
+     "DistilBERT": "distilbert-base-uncased",
+ }
+
+ class LoadModelRequest(BaseModel):
+     model: str
+     sentence: str
+     task: str
+     hypothesis: str
+
+ class GradAttnModelRequest(BaseModel):
+     model: str
+     task: str
+     sentence: str
+     hypothesis: str
+     maskID: int | None = None
+
+ class PredModelRequest(BaseModel):
+     model: str
+     sentence: str
+     task: str
+     hypothesis: str
+     maskID: int | None = None
+
+
+ @app.post("/load_model")
57
+ def load_model(req: LoadModelRequest):
58
+ print(f"\n--- /load_model request received ---")
59
+ print(f"Model: {req.model}")
60
+ print(f"Sentence: {req.sentence}")
61
+ print(f"Task: {req.task}")
62
+ print(f"hypothesis: {req.hypothesis}")
63
+
64
+
65
+ if req.model in VISUALIZER_CACHE:
66
+ del VISUALIZER_CACHE[req.model]
67
+ torch.cuda.empty_cache()
68
+
69
+ vis_class = VISUALIZER_CLASSES.get(req.model)
70
+ if vis_class is None:
71
+ return {"error": f"Unknown model: {req.model}"}
72
+
73
+ print("instantiating visualizer")
74
+ try:
75
+ vis = vis_class(task=req.task.lower())
76
+ print(vis)
77
+ VISUALIZER_CACHE[req.model] = vis
78
+ print("Visualizer instantiated")
79
+ except Exception as e:
80
+ print("Visualizer init failed:", e)
81
+ return {"error": f"Instantiation failed: {str(e)}"}
82
+
83
+ print('tokenizing')
84
+ try:
85
+ if req.task.lower() == 'mnli':
86
+ token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
87
+ else:
88
+ token_output = vis.tokenize(req.sentence)
89
+ print("0 Tokenization successful:", token_output["tokens"])
90
+ except Exception as e:
91
+ print("Tokenization failed:", e)
92
+ return {"error": f"Tokenization failed: {str(e)}"}
93
+
94
+ print('response ready')
95
+ response = {
96
+ "model": req.model,
97
+ "tokens": token_output['tokens'],
98
+ "num_layers": vis.num_attention_layers,
99
+ }
100
+ print("load model successful")
101
+ print(response)
102
+ return response
103
+
104
+
105
+
106
+
107
+
108
+ @app.post("/predict_model")
109
+ def predict_model(req: PredModelRequest):
110
+
111
+ print(f"\n--- /predict_model request received ---")
112
+ print(f"predict: Model: {req.model}")
113
+ print(f"predict: Task: {req.task}")
114
+ print(f"predict: sentence: {req.sentence}")
115
+ print(f"predict: hypothesis: {req.hypothesis}")
116
+ print(f"predict: maskID: {req.maskID}")
117
+
118
+
119
+
120
+ print('predict: instantiating')
121
+ try:
122
+ vis_class = VISUALIZER_CLASSES.get(req.model)
123
+ if vis_class is None:
124
+ return {"error": f"Unknown model: {req.model}"}
125
+ #if any(p.device.type == 'meta' for p in vis.model.parameters()):
126
+ # vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
127
+
128
+ vis = vis_class(task=req.task.lower())
129
+ VISUALIZER_CACHE[req.model] = vis
130
+ print("Model reloaded and cached.")
131
+ except Exception as e:
132
+ return {"error": f"Failed to reload model: {str(e)}"}
133
+
134
+ print('predict: meta stuff')
135
+
136
+
137
+
138
+ print('predict: Run prediction')
139
+ try:
140
+ if req.task.lower() == 'mnli':
141
+ decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
142
+ elif req.task.lower() == 'mlm':
143
+ decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
144
+
145
+ else:
146
+ decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
147
+ except Exception as e:
148
+ decoded, top_probs = "error", e
149
+ print(e)
150
+
151
+ print('predict: response ready')
152
+ response = {
153
+ "decoded": decoded,
154
+ "top_probs": top_probs.tolist(),
155
+ }
156
+ print("predict: predict model successful")
157
+ if len(decoded) > 5:
158
+ print([(k,v[:5]) for k,v in response.items()])
159
+ else:
160
+ print(response)
161
+ return response
162
+
163
+
164
+
165
+ @app.post("/get_grad_attn_matrix")
166
+ def get_grad_attn_matrix(req: GradAttnModelRequest):
167
+
168
+ try:
169
+ print(f"\n--- /get_grad_matrix request received ---")
170
+ print(f"grad:Model: {req.model}")
171
+ print(f"grad:Task: {req.task}")
172
+ print(f"grad:sentence: {req.sentence}")
173
+ print(f"grad: hypothesis: {req.hypothesis}")
174
+ print(f"predict: maskID: {req.maskID}")
175
+
176
+
177
+
178
+ try:
179
+ vis_class = VISUALIZER_CLASSES.get(req.model)
180
+ if vis_class is None:
181
+ return {"error": f"Unknown model: {req.model}"}
182
+ #if any(p.device.type == 'meta' for p in vis.model.parameters()):
183
+ # vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
184
+ vis = vis_class(task=req.task.lower())
185
+ VISUALIZER_CACHE[req.model] = vis
186
+ print("Model reloaded and cached.")
187
+ except Exception as e:
188
+ return {"error": f"Failed to reload model: {str(e)}"}
189
+
190
+
191
+
192
+ print("run function")
193
+ try:
194
+ if req.task.lower()=='mnli':
195
+ grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence,hypothesis=req.hypothesis)
196
+ elif req.task.lower()=='mlm':
197
+ grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence,maskID=req.maskID)
198
+ else:
199
+ grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
200
+ except Exception as e:
201
+ print("Exception during grad/attn computation:", e)
202
+ grad_matrix, attn_matrix = e,e
203
+
204
+
205
+ response = {
206
+ "grad_matrix": grad_matrix,
207
+ "attn_matrix": attn_matrix,
208
+ }
209
+ print('grad attn successful')
210
+ return response
211
+ except Exception as e:
212
+ print("SERVER EXCEPTION:", e)
213
+ return {"error": str(e)}
214
+
215
+
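
Finally, a hedged client-side sketch of the request flow against the three endpoints above. It assumes the server is running locally on port 8000 and that the requests package is installed; neither assumption is part of this commit.

# Hypothetical client mirroring LoadModelRequest and PredModelRequest.
import requests

BASE = "http://localhost:8000"
payload = {"model": "BERT", "sentence": "The capital of France is Paris.",
           "task": "mlm", "hypothesis": ""}

loaded = requests.post(f"{BASE}/load_model", json=payload).json()
print(loaded["tokens"], loaded["num_layers"])

pred = requests.post(f"{BASE}/predict_model", json={**payload, "maskID": 6}).json()
print(pred["decoded"], pred["top_probs"])

The /get_grad_attn_matrix endpoint takes the same fields as /predict_model and returns one gradient-norm matrix and one mean-attention matrix per encoder layer.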