yifan0sun committed · verified
Commit 7c7c06a · Parent(s): 2015ce0

Update ROBERTAmodel.py

Files changed (1):
  1. ROBERTAmodel.py (+11 -6)
ROBERTAmodel.py CHANGED
@@ -193,17 +193,20 @@ class RoBERTaVisualizer(TransformerVisualizer):
                 attention_mask=inputs["attention_mask"],
                 output_attentions=True
             )
+            attentions = outputs.attentions
             attentions_condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
-            print([a.shape for a in attentions_condensed])
             attentions_condensed= torch.vstack(attentions_condensed)
-            print(attentions_condensed.shape)
             return attentions_condensed
 
         start = time.time()
-        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds).norm(dim=-1).squeeze(dim=2)
+        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
         print(jac.shape)
-        grad_matrices_all = [jac[i] for i in range(jac.size(0))]
-
+        jac = jac.norm(dim=-1).squeeze(dim=2)
+        print(jac.shape)
+        seq_len = jac.shape[0]
+        print(seq_len)
+        grad_matrices_all = [jac[ii,:,:].tolist() for ii in range(seq_len)]
+
         print(31,time.time()-start)
         attn_matrices_all = []
         for target_layer in range(len(attentions)):
@@ -212,7 +215,9 @@ class RoBERTaVisualizer(TransformerVisualizer):
             attn_matrix = mean_attns[target_layer]
             seq_len = attn_matrix.shape[0]
             attn_matrix = attn_matrix[:seq_len, :seq_len]
+            print(4,attn_matrix.shape)
             attn_matrices_all.append(attn_matrix.tolist())
+
             print(3,time.time()-start)
 
 
@@ -259,6 +264,6 @@ class RoBERTaVisualizer(TransformerVisualizer):
             grad_norms_list.append(grad_norms.unsqueeze(1))
             print(2,time.time()-start)
             """
-
+        #print(grad_matrices_all)
         return grad_matrices_all, attn_matrices_all
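For context, a minimal self-contained sketch of the pattern this commit converges on: run RoBERTa with output_attentions=True, condense each layer's attention into per-token scores, take the Jacobian of those scores with respect to the input embeddings, and collapse the hidden dimension with a norm to get one [seq_len, seq_len] gradient matrix per layer. This is an illustrative approximation, not the repository's exact code: the roberta-base checkpoint, the example sentence, and the standalone scalar_outputs helper are assumptions, and batch size 1 is assumed (which is why squeeze(dim=2) removes the batch axis).

# Illustrative sketch only; assumed names: roberta-base checkpoint, example
# sentence, standalone scalar_outputs helper (the repo wraps this logic in
# RoBERTaVisualizer).
import torch
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base")
model.eval()

inputs = tokenizer("A short example sentence.", return_tensors="pt")
inputs_embeds = model.embeddings.word_embeddings(inputs["input_ids"])  # [1, seq_len, hidden]

def scalar_outputs(embeds):
    # Condense each layer's attention [batch, heads, seq, seq] into a per-token
    # score vector [seq], then stack the layers into [num_layers, seq].
    outputs = model(
        inputs_embeds=embeds,
        attention_mask=inputs["attention_mask"],
        output_attentions=True,
    )
    attentions = outputs.attentions
    condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
    return torch.vstack(condensed)

# Jacobian shape: [num_layers, seq, 1, seq, hidden]; the norm collapses the
# hidden dimension and squeeze(dim=2) drops the batch axis (batch size 1).
jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
jac = jac.norm(dim=-1).squeeze(dim=2)  # [num_layers, seq, seq]
grad_matrices_all = [jac[i].tolist() for i in range(jac.shape[0])]

# Plain attention matrices for comparison: average over batch and heads per layer.
with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds,
                    attention_mask=inputs["attention_mask"],
                    output_attentions=True)
attn_matrices_all = [a.mean(dim=(0, 1)).tolist() for a in outputs.attentions]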