yifan0sun committed on
Commit
6362f90
·
verified ·
1 Parent(s): d37fcc1

Update ROBERTAmodel.py

Browse files
Files changed (1) hide show
  1. ROBERTAmodel.py +7 -6
ROBERTAmodel.py CHANGED
@@ -215,11 +215,11 @@ class RoBERTaVisualizer(TransformerVisualizer):
215
  attn_matrix = attn_matrix.to(torch.float16)
216
 
217
  attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
218
-
219
- print('Computing grad norms')
 
220
  grad_norms_list = []
221
  for k in range(seq_len):
222
- start = time.time()
223
  scalar = attn_layer[:, k].sum()
224
  grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
225
  grad_norms = grad.norm(dim=1)
@@ -228,15 +228,16 @@ class RoBERTaVisualizer(TransformerVisualizer):
228
  grad_norms = torch.round(grad_norms.unsqueeze(1).float() * 100) / 100
229
  grad_norms = grad_norms.to(torch.float16)
230
 
231
- start = time.time()
232
 
233
  grad_norms_list.append(grad_norms)
234
 
 
235
  start = time.time()
236
- print(9,time.time()-start)
237
  grad_matrix = torch.cat(grad_norms_list, dim=1)
238
  grad_matrix = grad_matrix[:seq_len, :seq_len]
239
  attn_matrix = attn_matrix[:seq_len, :seq_len]
 
 
240
  start = time.time()
241
 
242
  attn_matrix = torch.round(attn_matrix.float() * 100) / 100
@@ -244,7 +245,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
244
 
245
  grad_matrix = torch.round(grad_matrix.float() * 100) / 100
246
  grad_matrix = grad_matrix.to(torch.float16)
247
- start = time.time()
248
 
249
 
250
 
 
215
  attn_matrix = attn_matrix.to(torch.float16)
216
 
217
  attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
218
+ print(9,time.time()-start)
219
+ start = time.time()
220
+ #print('Computing grad norms')
221
  grad_norms_list = []
222
  for k in range(seq_len):
 
223
  scalar = attn_layer[:, k].sum()
224
  grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
225
  grad_norms = grad.norm(dim=1)
 
228
  grad_norms = torch.round(grad_norms.unsqueeze(1).float() * 100) / 100
229
  grad_norms = grad_norms.to(torch.float16)
230
 
 
231
 
232
  grad_norms_list.append(grad_norms)
233
 
234
+ print(10,time.time()-start)
235
  start = time.time()
 
236
  grad_matrix = torch.cat(grad_norms_list, dim=1)
237
  grad_matrix = grad_matrix[:seq_len, :seq_len]
238
  attn_matrix = attn_matrix[:seq_len, :seq_len]
239
+
240
+ print(11,time.time()-start)
241
  start = time.time()
242
 
243
  attn_matrix = torch.round(attn_matrix.float() * 100) / 100
 
245
 
246
  grad_matrix = torch.round(grad_matrix.float() * 100) / 100
247
  grad_matrix = grad_matrix.to(torch.float16)
248
+ print(12,time.time()-start)
249
 
250
 
251