yifan0sun committed
Commit 1f7f45a · verified · 1 Parent(s): 9b48d92

Update ROBERTAmodel.py

Files changed (1)
ROBERTAmodel.py +5 -35
ROBERTAmodel.py CHANGED
@@ -194,7 +194,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
             attn_matrices_all.append(attn_matrix.tolist())
 
 
-
+        """
         start = time.time()
         def scalar_outputs(inputs_embeds):
 
@@ -210,14 +210,12 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
         grad_matrices_all.append(jac.tolist())
         print(1,time.time()-start)
-
+        """
         start = time.time()
         grad_norms_list = []
-
+        scalar_layer = attentions[target_layer].mean(dim=0).mean(dim=0)
         for k in range(seq_len):
-            scalar = attentions[target_layer].mean(dim=0).mean(dim=0)
-            scalar = scalar[:, k].sum()
-
+            scalar = scalar_layer[:, k].sum()
             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
 
             grad_norms = grad.norm(dim=1)
@@ -225,32 +223,4 @@ class RoBERTaVisualizer(TransformerVisualizer):
         print(2,time.time()-start)
 
         return grad_matrices_all, attn_matrices_all
-
-    def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-
-        attn_matrix = mean_attns[target_layer]
-        seq_len = attn_matrix.shape[0]
-        attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
-        """
-        print('Computing grad norms')
-        grad_norms_list = []
-
-        for k in range(seq_len):
-            scalar = attn_layer[:, k].sum()
-
-            grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
-
-            grad_norms = grad.norm(dim=1)
-            grad_norms_list.append(grad_norms.unsqueeze(1))
-
-        grad_matrix = torch.cat(grad_norms_list, dim=1)
-
-
-
-
-        grad_matrix = grad_matrix[:seq_len, :seq_len]
-        """
-        attn_matrix = attn_matrix[:seq_len, :seq_len]
-        grad_matrix = attn_matrix
-
-        return grad_matrix, attn_matrix
+
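Read from the hunks above, the commit does three things: it disables the per-token Jacobian path (the scalar_outputs closure whose jac result fed grad_matrices_all) by wrapping it in a triple-quoted string, it hoists the batch- and head-averaged attention map out of the per-token gradient loop into scalar_layer so the [heads, seq, seq] reduction runs once per layer rather than once per token, and it deletes the dead get_grad_attn_matrix helper, whose own gradient logic was already commented out and which simply returned the attention matrix twice.

The following is a minimal, self-contained sketch of the saliency loop the commit keeps. It is an illustration under assumptions, not the repository's code: the roberta-base checkpoint, the sample sentence, and target_layer = 0 are placeholders, and the RoBERTaVisualizer plumbing around the loop is omitted.

import torch
from transformers import AutoTokenizer, RobertaModel

tokenizer = AutoTokenizer.from_pretrained("roberta-base")   # placeholder checkpoint
model = RobertaModel.from_pretrained("roberta-base")
model.eval()

enc = tokenizer("Gradient saliency demo.", return_tensors="pt")

# Differentiate w.r.t. the input embeddings: detach to get a leaf tensor.
inputs_embeds = model.embeddings.word_embeddings(enc["input_ids"]).detach().requires_grad_(True)

out = model(inputs_embeds=inputs_embeds,
            attention_mask=enc["attention_mask"],
            output_attentions=True)
attentions = out.attentions          # one [batch, heads, seq, seq] tensor per layer
target_layer = 0                     # placeholder layer choice
seq_len = inputs_embeds.shape[1]

# The hoisted reduction from the commit: average over batch and heads
# once per layer, giving a [seq, seq] attention map.
scalar_layer = attentions[target_layer].mean(dim=0).mean(dim=0)

grad_norms_list = []
for k in range(seq_len):
    scalar = scalar_layer[:, k].sum()   # total attention received by token k
    grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
    grad_norms_list.append(grad.norm(dim=1).unsqueeze(1))   # [seq, 1] column per token

grad_matrix = torch.cat(grad_norms_list, dim=1)   # [seq, seq] gradient-norm map
print(grad_matrix.shape)

The hoisting is the entire speedup: scalar_layer is identical for every k, so computing it inside the loop repeated the same reduction seq_len times. The paired print(1, ...) and print(2, ...) timers bracket the disabled Jacobian path and the retained loop, which suggests the two approaches were timed against each other before the slower one was commented out; since the body of scalar_outputs is elided from the diff, that reading is an inference from the visible names.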