File size: 324 Bytes
04ccab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
993b547
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch




class TransformerVisualizer:
    """Minimal stand-in for a transformer visualization helper.

    NOTE(review): every method below returns its own arguments followed by
    the integer ``1`` and runs no model at all. This looks like a stub or
    mock. Confirm before relying on it for real predictions or attention
    maps.
    """

    def __init__(self):
        # The device is hard-coded to CPU; no accelerator selection happens.
        cpu_device = torch.device('cpu')
        self.device = cpu_device

    def predict(self, task, text):
        """Return ``(task, text, 1)`` without performing any inference.

        Args:
            task: Opaque value, returned unchanged as the first element.
            text: Opaque value, returned unchanged as the second element.

        Returns:
            A 3-tuple ``(task, text, 1)``.
        """
        placeholder = 1
        result = (task, text, placeholder)
        return result

    def get_attention_gradient_matrix(self, task, text, target_layer):
        """Return ``(task, text, target_layer, 1)``; no gradients are computed.

        Args:
            task: Opaque value, returned unchanged as the first element.
            text: Opaque value, returned unchanged as the second element.
            target_layer: Opaque value, returned unchanged as the third element.

        Returns:
            A 4-tuple ``(task, text, target_layer, 1)``.
        """
        placeholder = 1
        result = (task, text, target_layer, placeholder)
        return result