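"""Gradient-rollout attribution for a RoBERTa model fine-tuned on SST-2 sentiment.

Uses a patched RobertaForSequenceClassification (roberta2) that exposes per-layer
attention maps and their gradients via get_attn() / get_attn_gradients(), and renders
per-token relevance scores with Captum's visualization utilities.
"""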
import torch
from transformers import AutoTokenizer
from captum.attr import visualization

from roberta2 import RobertaForSequenceClassification
from util import visualize_text, PyTMinMaxScalerVectorized

classifications = ["NEGATIVE", "POSITIVE"]

class GradientRolloutExplainer:
    def __init__(self):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = RobertaForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2").to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")

    def tokens_from_ids(self, ids):
        # strip the leading "Ġ" that RoBERTa's BPE tokenizer adds to word-initial tokens
        return [s[1:] if s.startswith("Ġ") else s for s in self.tokenizer.convert_ids_to_tokens(ids)]
    def run_attribution_model(self, input_ids, attention_mask, index=None, start_layer=0):
        def avg_heads(cam, grad):
            # weight each attention map by its gradient, set negative values to 0,
            # then average over the head dimension
            cam = (grad * cam).clamp(min=0).mean(dim=-3)
            return cam

        def apply_self_attention_rules(R_ss, cam_ss):
            # propagate the accumulated relevance through one attention layer
            R_ss_addition = torch.matmul(cam_ss, R_ss)
            return R_ss_addition

        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if index is None:
            # by default explain the class with the highest score
            index = output.argmax(dim=-1).detach().cpu().numpy()

        # create a one-hot vector selecting the class we want explanations for
        one_hot = (
            torch.nn.functional.one_hot(
                torch.tensor(index, dtype=torch.int64), num_classes=output.size(-1)
            )
            .to(torch.float)
            .requires_grad_(True)
        ).to(self.device)
        one_hot = torch.sum(one_hot * output)
        self.model.zero_grad()
        # backpropagate the selected class score to get gradients on the attention maps
        one_hot.backward(retain_graph=True)

        num_tokens = self.model.roberta.encoder.layer[0].attention.self.get_attn().shape[-1]
        # start from the identity: each token is initially relevant only to itself
        R = torch.eye(num_tokens).expand(output.size(0), -1, -1).clone().to(self.device)
        for i, blk in enumerate(self.model.roberta.encoder.layer):
            if i < start_layer:
                continue
            grad = blk.attention.self.get_attn_gradients()
            cam = blk.attention.self.get_attn()
            cam = avg_heads(cam, grad)
            joint = apply_self_attention_rules(R, cam)
            R += joint
        # relevance of each token to the first ([CLS]) position, excluding the special start/end tokens
        return output, R[:, 0, 1:-1]
    def build_visualization(self, input_ids, attention_mask, index=None, start_layer=8):
        # generate an explanation of the input with respect to each class
        vis_data_records = []
        for index in range(2):
            output, expl = self.run_attribution_model(
                input_ids, attention_mask, index=index, start_layer=start_layer
            )
            # normalize scores to [0, 1]
            scaler = PyTMinMaxScalerVectorized()
            norm = scaler(expl)
            # get the model classification
            output = torch.nn.functional.softmax(output, dim=-1)

            for record in range(input_ids.size(0)):
                classification = output[record].argmax(dim=-1).item()
                class_name = classifications[classification]
                nrm = norm[record]

                # when explaining the NEGATIVE class (index 0), flip the sign so that
                # higher relevance renders as negative in the visualization
                if index == 0:
                    nrm *= -1

                # drop the leading <s> token and everything from </s> onward (incl. padding)
                tokens = self.tokens_from_ids(input_ids[record].flatten())[
                    1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
                ]
                vis_data_records.append(
                    visualization.VisualizationDataRecord(
                        nrm,                             # word attributions
                        output[record][classification],  # predicted probability
                        classification,                  # predicted class
                        classification,                  # true class (ground truth unavailable; reuse prediction)
                        index,                           # class being attributed
                        1,                               # attribution score (fixed)
                        tokens,                          # raw input tokens
                        1,                               # convergence score (fixed)
                    )
                )
        return visualize_text(vis_data_records)
    def __call__(self, input_text, start_layer=8):
        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        return self.build_visualization(input_ids, attention_mask, start_layer=int(start_layer))
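
# Example usage (sketch): instantiate the explainer and render attributions for a
# sample sentence. The return value is whatever util.visualize_text produces; with
# Captum's visualize_text this would be an HTML object suitable for display in a
# notebook or web UI.
if __name__ == "__main__":
    explainer = GradientRolloutExplainer()
    html = explainer("This movie was surprisingly good.", start_layer=8)
    print(html)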