# BERTGradGraph / BERTmodel.py
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
    BertForSequenceClassification,
)

from models import TransformerVisualizer
CACHE_DIR = "/data/hf_cache"
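# Models and tokenizers are loaded with local_files_only=True, so they must
# already be present under CACHE_DIR; the triple-quoted blocks below preserve
# the disabled download-and-cache fallback.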
class BERTVisualizer(TransformerVisualizer):
    def __init__(self, task):
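        """Load the BERT tokenizer and a task-specific model head.

        task: 'mlm' (masked language modeling), 'sst' (sentiment
        classification), or 'mnli' (natural language inference).
        """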
super().__init__()
self.task = task
print(task,'BERT VIS START')
TOKENIZER = 'bert-base-uncased'
LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER)
self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
"""
try:
self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
except Exception as e:
self.tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
self.tokenizer.save_pretrained(LOCAL_PATH)
"""
print('finding model', self.task)
        if self.task == 'mlm':
            MODEL = 'bert-base-uncased'
            LOCAL_PATH = os.path.join(CACHE_DIR, "models", MODEL)
            self.model = BertForMaskedLM.from_pretrained(
                LOCAL_PATH, local_files_only=True, attn_implementation="eager"
            ).to(self.device)
"""
try:
self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
except Exception as e:
self.model = BertForMaskedLM.from_pretrained( MODEL, attn_implementation="eager" ).to(self.device)
self.model.save_pretrained(LOCAL_PATH)
"""
        elif self.task == 'sst':
            # Hub id uses a slash; the local cache directory replaces it with
            # an underscore to stay filesystem-safe
            MODEL = 'textattack/bert-base-uncased-SST-2'
            LOCAL_PATH = os.path.join(CACHE_DIR, "models", MODEL.replace('/', '_'))
            self.model = BertForSequenceClassification.from_pretrained(
                LOCAL_PATH, local_files_only=True, device_map=None
            ).to(self.device)
"""
try:
self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
except Exception as e:
self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None )
self.model.save_pretrained(LOCAL_PATH)
"""
        elif self.task == 'mnli':
            MODEL = 'textattack/bert-base-uncased-MNLI'
            LOCAL_PATH = os.path.join(CACHE_DIR, "models", MODEL.replace('/', '_'))
            self.model = BertForSequenceClassification.from_pretrained(
                LOCAL_PATH, local_files_only=True, device_map=None
            ).to(self.device)
"""
try:
self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
except Exception as e:
self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None)
self.model.save_pretrained(LOCAL_PATH)
"""
else:
raise ValueError(f"Unsupported task: {self.task}")
print('model found')
self.model.to(self.device)
# Force materialization of all layers (avoids meta device errors)
with torch.no_grad():
dummy_ids = torch.tensor([[0, 1]], device=self.device)
dummy_mask = torch.tensor([[1, 1]], device=self.device)
_ = self.model(input_ids=dummy_ids, attention_mask=dummy_mask)
self.model.eval()
print('self model eval')
self.num_attention_layers = len(self.model.bert.encoder.layer)
print('init finished')
    def tokenize(self, text, hypothesis=''):
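        """Tokenize text (with an optional hypothesis for sentence pairs) and
        return input ids, attention mask, and decoded token strings."""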
print('TTTokenize',text,'H:', hypothesis)
if len(hypothesis) == 0:
encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True)
else:
encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True)
input_ids = encoded['input_ids'].to(self.device)
attention_mask = encoded['attention_mask'].to(self.device)
tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'tokens': tokens
}
    def predict(self, task, text, hypothesis='', maskID=None):
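        """Run task-specific inference and return (labels, probs): the top-10
        predicted tokens for 'mlm', sentiment labels for 'sst', or entailment
        labels for 'mnli'."""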
print(task,text,hypothesis)
if task == 'mlm':
# Tokenize and find [MASK] position
print('Tokenize and find [MASK] position')
inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
mask_index = maskID
else:
raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
# Move to device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Get embeddings and run the encoder + MLM head; inference needs no
            # gradients, so wrap the forward pass in no_grad
            with torch.no_grad():
                embedding_layer = self.model.bert.embeddings.word_embeddings
                inputs_embeds = embedding_layer(inputs['input_ids'])
                hidden_states = self.model.bert(
                    inputs_embeds=inputs_embeds,
                    attention_mask=inputs['attention_mask'],
                ).last_hidden_state
                # Predict logits via the MLM head
                logits = self.model.cls(hidden_states)
            # Softmax over the full vocabulary first, then take the top 10;
            # softmaxing only the top-10 logits would renormalize them to sum to 1
            mask_logits = logits[0, mask_index]
            probs = F.softmax(mask_logits, dim=-1)
            top_probs, top_indices = torch.topk(probs, k=10, dim=-1)
            decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
            return decoded, top_probs
elif task == 'sst':
print('input')
inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
print('output')
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits # shape: [1, 2]
probs = F.softmax(logits, dim=1).squeeze()
labels = ["negative", "positive"]
print('ready to return')
return labels, probs
elif task == 'mnli':
inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probs = F.softmax(logits, dim=1).squeeze()
labels = ["entailment", "neutral", "contradiction"]
return labels, probs
    def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID=0):
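        """Return, for every encoder layer, a gradient-sensitivity matrix
        (Jacobian norms of condensed attention w.r.t. the input embeddings)
        and the head-averaged attention matrix."""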
print('GET GRAD:', task,'sentence',sentence, 'hypothesis', hypothesis)
print('Tokenize')
if task == 'mnli':
inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
elif task == 'mlm':
inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
else:
raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
else:
inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
inputs = {k: v.to(self.device) for k, v in inputs.items()}
print(inputs['input_ids'].shape)
print(tokens,len(tokens))
print('Input embeddings with grad')
embedding_layer = self.model.bert.embeddings.word_embeddings
inputs_embeds = embedding_layer(inputs["input_ids"])
inputs_embeds.requires_grad_()
print('Forward pass')
outputs = self.model.bert(
inputs_embeds=inputs_embeds,
attention_mask=inputs["attention_mask"],
output_attentions=True
)
attentions = outputs.attentions # list of [1, heads, seq, seq]
print('Average attentions per layer')
mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
def scalar_outputs(inputs_embeds):
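            # Recompute attention as a differentiable function of the input
            # embeddings; condense each layer's [1, heads, seq, seq] map by
            # averaging over batch and heads and summing over the query dim,
            # leaving per-key attention mass -> a [layers, seq] output to
            # differentiate.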
outputs = self.model.bert(
inputs_embeds=inputs_embeds,
attention_mask=inputs["attention_mask"],
output_attentions=True
)
attentions = outputs.attentions
attentions_condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
attentions_condensed= torch.vstack(attentions_condensed)
return attentions_condensed
        start = time.time()
        # Jacobian of the [layers, seq] condensed attention w.r.t. the
        # [1, seq, emb] input embeddings: shape [layers, seq, 1, seq, emb]
        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds).to(torch.float16)
        print('time to get jacobian: ', time.time() - start)
        # Norm over the embedding dim, drop the batch dim -> [layers, seq, seq]
        jac = jac.norm(dim=-1).squeeze(dim=2)
        num_layers = jac.shape[0]  # first dim indexes layers, not sequence positions
        grad_matrices_all = [jac[ii, :, :].tolist() for ii in range(num_layers)]
        attn_matrices_all = []
        for target_layer in range(len(attentions)):
            # mean_attns[target_layer] is already a [seq, seq] matrix
            attn_matrices_all.append(mean_attns[target_layer].tolist())
return grad_matrices_all, attn_matrices_all
if __name__ == "__main__":
import sys
    # Only BERTVisualizer is defined in this file; the RoBERTa, DistilBERT,
    # and BART visualizers live in their own modules and are not imported here.
    MODEL_CLASSES = {
        "bert": BERTVisualizer,
    }
# Parse command-line args or fallback to default
model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
if model_name.lower() not in MODEL_CLASSES:
print(f"Supported models: {list(MODEL_CLASSES.keys())}")
sys.exit(1)
    # Instantiate the visualizer; __init__ requires a task, so use 'mlm' for this smoke test
    visualizer_class = MODEL_CLASSES[model_name.lower()]
    visualizer = visualizer_class('mlm')
# Tokenize
token_info = visualizer.tokenize(text)
# Report
print(f"\nModel: {model_name}")
print(f"Num attention layers: {visualizer.num_attention_layers}")
print(f"Tokens: {token_info['tokens']}")
print(f"Input IDs: {token_info['input_ids'].tolist()}")
print(f"Attention mask: {token_info['attention_mask'].tolist()}")
"""
usage for debug:
python your_file.py bert "The rain in Spain falls mainly on the plain."
"""