Update ROBERTAmodel.py
ROBERTAmodel.py  +10 -2  CHANGED
@@ -10,6 +10,8 @@ import torch.autograd.functional as Fgrad
 
 CACHE_DIR = "/data/hf_cache"
 
+
+
 class RoBERTaVisualizer(TransformerVisualizer):
     def __init__(self, task):
         super().__init__()
@@ -73,6 +75,12 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
 
         self.model.to(self.device)
+        # Force materialization of all layers (avoids meta device errors)
+        with torch.no_grad():
+            dummy_ids = torch.tensor([[0, 1]], device=self.device)
+            dummy_mask = torch.tensor([[1, 1]], device=self.device)
+            _ = self.model(input_ids=dummy_ids, attention_mask=dummy_mask)
+
         self.model.eval()
         self.num_attention_layers = self.model.config.num_hidden_layers
 
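The block added to __init__ runs one throwaway forward pass under torch.no_grad() so every layer is exercised once and any lazily-initialized (meta-device) parameters are materialized before eval(). A minimal standalone sketch of the same warm-up pattern, assuming a stock transformers checkpoint; the roberta-base name and the device handling are illustrative, not taken from this Space:

# Warm-up sketch; "roberta-base" stands in for whatever checkpoint the Space loads.
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("roberta-base").to(device)

# One tiny dummy batch touches every layer once, so all weights are
# materialized on `device` before real inputs arrive.
with torch.no_grad():
    dummy_ids = torch.tensor([[0, 1]], device=device)   # two arbitrary valid token ids
    dummy_mask = torch.tensor([[1, 1]], device=device)
    _ = model(input_ids=dummy_ids, attention_mask=dummy_mask)

model.eval()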
@@ -199,14 +207,14 @@ class RoBERTaVisualizer(TransformerVisualizer):
             return attentions_condensed
 
         start = time.time()
-        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
+        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds).to(torch.float16)
         print(jac.shape)
         jac = jac.norm(dim=-1).squeeze(dim=2)
         print(jac.shape)
         seq_len = jac.shape[0]
         print(seq_len)
         grad_matrices_all = [jac[ii,:,:].tolist() for ii in range(seq_len)]
-
+
         print(31,time.time()-start)
         attn_matrices_all = []
         for target_layer in range(len(attentions)):