yifan0sun committed (verified)
Commit e60ad79 · Parent(s): a42cb18

Update ROBERTAmodel.py
Files changed (1)
  1. ROBERTAmodel.py +10 -2
ROBERTAmodel.py CHANGED
@@ -10,6 +10,8 @@ import torch.autograd.functional as Fgrad
 
 CACHE_DIR = "/data/hf_cache"
 
+
+
 class RoBERTaVisualizer(TransformerVisualizer):
     def __init__(self, task):
         super().__init__()
@@ -73,6 +75,12 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
 
         self.model.to(self.device)
+        # Force materialization of all layers (avoids meta device errors)
+        with torch.no_grad():
+            dummy_ids = torch.tensor([[0, 1]], device=self.device)
+            dummy_mask = torch.tensor([[1, 1]], device=self.device)
+            _ = self.model(input_ids=dummy_ids, attention_mask=dummy_mask)
+
         self.model.eval()
         self.num_attention_layers = self.model.config.num_hidden_layers
 
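The block added above runs one throwaway forward pass right after `model.to(self.device)`; per its comment, this forces every layer's weights to materialize so later calls don't hit meta-device errors (as can happen when weights are loaded lazily). A minimal self-contained sketch of the same trick, assuming a plain `roberta-base` checkpoint (the commit's actual model and task head are configured elsewhere in ROBERTAmodel.py):

    import torch
    from transformers import RobertaModel

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RobertaModel.from_pretrained("roberta-base")  # assumed checkpoint
    model.to(device)

    # One tiny dummy batch materializes all layer weights on `device`;
    # token ids 0 and 1 are <s> and <pad> in the RoBERTa vocabulary.
    with torch.no_grad():
        dummy_ids = torch.tensor([[0, 1]], device=device)
        dummy_mask = torch.tensor([[1, 1]], device=device)
        _ = model(input_ids=dummy_ids, attention_mask=dummy_mask)

    model.eval()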
@@ -199,14 +207,14 @@ class RoBERTaVisualizer(TransformerVisualizer):
             return attentions_condensed
 
         start = time.time()
-        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
+        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds).to(torch.float16)
         print(jac.shape)
         jac = jac.norm(dim=-1).squeeze(dim=2)
         print(jac.shape)
         seq_len = jac.shape[0]
         print(seq_len)
         grad_matrices_all = [jac[ii,:,:].tolist() for ii in range(seq_len)]
-
+
         print(31,time.time()-start)
         attn_matrices_all = []
         for target_layer in range(len(attentions)):
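The substantive change here appends `.to(torch.float16)` to the Jacobian, roughly halving the memory retained by that large tensor once the float32 original is freed, before the per-token norms are taken. A runnable sketch of the pattern with assumed toy shapes (the real `scalar_outputs` closure and `inputs_embeds` are defined earlier in ROBERTAmodel.py; the small linear layer below is a stand-in for the model):

    import torch

    torch.manual_seed(0)
    seq_len, hidden = 4, 8
    inputs_embeds = torch.randn(1, seq_len, hidden)  # (batch=1, seq_len, hidden)
    probe = torch.nn.Linear(hidden, hidden)          # stand-in for the real model

    def scalar_outputs(embeds):
        # One scalar per token; the commit's closure scores model outputs instead.
        return probe(embeds).sum(dim=-1).squeeze(0)  # shape: (seq_len,)

    # Jacobian of (seq_len,) outputs w.r.t. (1, seq_len, hidden) inputs:
    # shape (seq_len, 1, seq_len, hidden).
    jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
    # The commit casts `jac` itself to float16; here the cast comes after the
    # norm so the sketch also runs on CPU builds that lack half-precision
    # reduction kernels.
    jac = jac.norm(dim=-1).squeeze(1).to(torch.float16)  # (seq_len, seq_len)
    grad_matrix = jac.tolist()
    print(len(grad_matrix), len(grad_matrix[0]))     # seq_len rows, seq_len cols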