Spaces:
Running
on
T4
Running
on
T4
Update ROBERTAmodel.py
Browse files- ROBERTAmodel.py +7 -6
ROBERTAmodel.py
CHANGED
@@ -215,11 +215,11 @@ class RoBERTaVisualizer(TransformerVisualizer):
|
|
215 |
attn_matrix = attn_matrix.to(torch.float16)
|
216 |
|
217 |
attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
|
218 |
-
|
219 |
-
|
|
|
220 |
grad_norms_list = []
|
221 |
for k in range(seq_len):
|
222 |
-
start = time.time()
|
223 |
scalar = attn_layer[:, k].sum()
|
224 |
grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
|
225 |
grad_norms = grad.norm(dim=1)
|
@@ -228,15 +228,16 @@ class RoBERTaVisualizer(TransformerVisualizer):
|
|
228 |
grad_norms = torch.round(grad_norms.unsqueeze(1).float() * 100) / 100
|
229 |
grad_norms = grad_norms.to(torch.float16)
|
230 |
|
231 |
-
start = time.time()
|
232 |
|
233 |
grad_norms_list.append(grad_norms)
|
234 |
|
|
|
235 |
start = time.time()
|
236 |
-
print(9,time.time()-start)
|
237 |
grad_matrix = torch.cat(grad_norms_list, dim=1)
|
238 |
grad_matrix = grad_matrix[:seq_len, :seq_len]
|
239 |
attn_matrix = attn_matrix[:seq_len, :seq_len]
|
|
|
|
|
240 |
start = time.time()
|
241 |
|
242 |
attn_matrix = torch.round(attn_matrix.float() * 100) / 100
|
@@ -244,7 +245,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
|
|
244 |
|
245 |
grad_matrix = torch.round(grad_matrix.float() * 100) / 100
|
246 |
grad_matrix = grad_matrix.to(torch.float16)
|
247 |
-
|
248 |
|
249 |
|
250 |
|
|
|
215 |
attn_matrix = attn_matrix.to(torch.float16)
|
216 |
|
217 |
attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
|
218 |
+
print(9,time.time()-start)
|
219 |
+
start = time.time()
|
220 |
+
#print('Computing grad norms')
|
221 |
grad_norms_list = []
|
222 |
for k in range(seq_len):
|
|
|
223 |
scalar = attn_layer[:, k].sum()
|
224 |
grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
|
225 |
grad_norms = grad.norm(dim=1)
|
|
|
228 |
grad_norms = torch.round(grad_norms.unsqueeze(1).float() * 100) / 100
|
229 |
grad_norms = grad_norms.to(torch.float16)
|
230 |
|
|
|
231 |
|
232 |
grad_norms_list.append(grad_norms)
|
233 |
|
234 |
+
print(10,time.time()-start)
|
235 |
start = time.time()
|
|
|
236 |
grad_matrix = torch.cat(grad_norms_list, dim=1)
|
237 |
grad_matrix = grad_matrix[:seq_len, :seq_len]
|
238 |
attn_matrix = attn_matrix[:seq_len, :seq_len]
|
239 |
+
|
240 |
+
print(11,time.time()-start)
|
241 |
start = time.time()
|
242 |
|
243 |
attn_matrix = torch.round(attn_matrix.float() * 100) / 100
|
|
|
245 |
|
246 |
grad_matrix = torch.round(grad_matrix.float() * 100) / 100
|
247 |
grad_matrix = grad_matrix.to(torch.float16)
|
248 |
+
print(12,time.time()-start)
|
249 |
|
250 |
|
251 |
|