# HarmFormer / modeling.py
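"""HarmFormer model definition.

A multi-head classifier on top of a Transformer encoder: one small
feed-forward head per harm category (`num_classes`), each predicting a
distribution over `num_risk_levels` risk levels from the pooled [CLS]
representation.
"""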
import torch
import torch.nn as nn
import os
import json
from transformers import AutoModel, PreTrainedModel
from huggingface_hub import hf_hub_download
class HarmFormer(PreTrainedModel):
def __init__(self, config):
super(HarmFormer, self).__init__(config)
self.num_classes = config.num_classes
self.num_risk_levels = config.num_risk_levels
# Base model
self.base_model = AutoModel.from_config(config)
# Classification heads
hidden_size = self.base_model.config.hidden_size
self.classifiers = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_size, 128),
nn.ReLU(),
nn.Linear(128, self.num_risk_levels)
)
for _ in range(self.num_classes)
])
def forward(self, input_ids=None, attention_mask=None, **kwargs):
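        """Encode the batch and apply every classification head.

        Returns a list of `num_classes` logit tensors, each of shape
        (batch_size, num_risk_levels).
        """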
outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs[1] # Pooled [CLS] token output
# Apply classifiers for each task
logits = []
for classifier in self.classifiers:
logits.append(classifier(pooled_output))
return logits
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
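        """Load a HarmFormer checkpoint (config.json + pytorch_model.bin)
        from a local directory or from the Hugging Face Hub."""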
# Load config
config_path = os.path.join(pretrained_model_name_or_path, "config.json")
if os.path.exists(config_path):
with open(config_path, 'r') as f:
model_config = json.load(f)
else:
            # Fall back to downloading the config from the Hugging Face Hub
config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.json")
with open(config_path, 'r') as f:
model_config = json.load(f)
# Create base model config
from transformers import AutoConfig
base_model_name = model_config.get("model_name", "allenai/longformer-base-4096")
base_config = AutoConfig.from_pretrained(base_model_name)
# Add our custom attributes
base_config.num_classes = model_config.get("num_classes", 5)
base_config.num_risk_levels = model_config.get("num_risk_levels", 3)
base_config.architecture = model_config.get("architecture", "SingleFC")
# Create model
model = cls(base_config)
# Load weights
checkpoint_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
if os.path.exists(checkpoint_path):
state_dict = torch.load(checkpoint_path, map_location="cpu")
else:
# Try to load from HF Hub
checkpoint_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="pytorch_model.bin")
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
return model
def predict_batch(model, tokenizer, texts, batch_size=32):
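    """Tokenize and score `texts` in batches of `batch_size`.

    Returns, for each input text, a `num_classes` x `num_risk_levels` grid of
    softmax probabilities rounded to three decimals.
    """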
device = next(model.parameters()).device
predictions = []
# Process in batches to avoid OOM
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
inputs = tokenizer(
batch_texts,
add_special_tokens=True,
max_length=1024,
truncation=True,
padding='max_length',
return_attention_mask=True,
return_tensors='pt',
).to(device)
with torch.no_grad():
outputs = model(**inputs)
logits = torch.stack(outputs, dim=0).permute(1, 0, 2) # (batch_size, num_classes, num_risk_levels)
probs = torch.softmax(logits, dim=-1)
batch_preds = [[[round(prob, 3) for prob in class_probs] for class_probs in sample] for sample in probs.cpu().tolist()]
predictions.extend(batch_preds)
return predictions
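

# Illustrative usage sketch. Assumptions: the checkpoint path below is a
# placeholder for the actual published repo/directory, and the tokenizer is
# the Longformer tokenizer matching the backbone that the config defaults to
# ("allenai/longformer-base-4096"); adjust both for the real checkpoint.
if __name__ == "__main__":
    from transformers import LongformerTokenizerFast

    model = HarmFormer.from_pretrained("path/to/harmformer-checkpoint")  # hypothetical path
    tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")

    sample_texts = ["An example document to screen for harmful content."]
    probs = predict_batch(model, tokenizer, sample_texts, batch_size=8)
    print(probs)  # one (num_classes x num_risk_levels) probability grid per text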