import json
import os

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    LongformerModel,
    LongformerTokenizerFast,
    PreTrainedModel,
)


def _resolve_artifact(pretrained_model_name_or_path, filename):
    """Return a local filesystem path to ``filename`` for a model location.

    Args:
        pretrained_model_name_or_path: Either a local directory or a
            Hugging Face Hub repo id.
        filename: Name of the file to locate (e.g. ``"config.json"``).

    Returns:
        The path inside the local directory when the file exists there,
        otherwise the path returned by ``hf_hub_download`` for the repo id.
    """
    local_path = os.path.join(pretrained_model_name_or_path, filename)
    if os.path.exists(local_path):
        return local_path

    # Imported lazily so that purely local loading never touches the Hub
    # client. Keeping the import here (rather than in one caller branch)
    # guarantees every fallback path has the name in scope.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=pretrained_model_name_or_path, filename=filename)


class HarmFormer(PreTrainedModel):
    """Encoder with ``num_classes`` independent classification heads.

    Each head is a two-layer MLP applied to the base model's pooled output
    and emits ``num_risk_levels`` logits.
    """

    def __init__(self, config):
        """Build the base encoder and the per-class heads.

        Args:
            config: A transformers config that additionally carries the
                ``num_classes`` and ``num_risk_levels`` attributes.
        """
        super().__init__(config)
        self.num_classes = config.num_classes
        self.num_risk_levels = config.num_risk_levels

        # Base model.
        # NOTE(review): PreTrainedModel also exposes a `base_model` property;
        # confirm this assignment resolves as intended on the installed
        # transformers version. The attribute name is kept because the
        # checkpoint's state-dict keys depend on it.
        self.base_model = AutoModel.from_config(config)

        # Classification heads: one small MLP per class.
        hidden_size = self.base_model.config.hidden_size
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Linear(128, self.num_risk_levels),
            )
            for _ in range(self.num_classes)
        ])

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the encoder and every classification head.

        Args:
            input_ids: Token ids forwarded to the base model.
            attention_mask: Attention mask forwarded to the base model.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A list with one logits tensor per classifier head, in head order.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # Pooled [CLS] token output

        # Apply the classifier of each task to the same pooled representation.
        return [classifier(pooled_output) for classifier in self.classifiers]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a HarmFormer from a local directory or a Hub repo id.

        Reads ``config.json`` and ``pytorch_model.bin`` from the given
        location, preferring local files and falling back to the Hub for
        each file independently.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            *model_args: Accepted for signature compatibility and ignored.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            The model with weights loaded on CPU and switched to eval mode.
        """
        # Load the HarmFormer-specific config.
        config_path = _resolve_artifact(pretrained_model_name_or_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            model_config = json.load(f)

        # Create the base model config from the referenced backbone.
        base_model_name = model_config.get("model_name", "allenai/longformer-base-4096")
        base_config = AutoConfig.from_pretrained(base_model_name)

        # Add our custom attributes.
        base_config.num_classes = model_config.get("num_classes", 5)
        base_config.num_risk_levels = model_config.get("num_risk_levels", 3)
        base_config.architecture = model_config.get("architecture", "SingleFC")

        # Create the model (randomly initialised until weights are loaded).
        model = cls(base_config)

        # Load weights.
        checkpoint_path = _resolve_artifact(
            pretrained_model_name_or_path, "pytorch_model.bin"
        )
        # NOTE(security): torch.load unpickles the checkpoint; only load
        # checkpoints from trusted sources (consider `weights_only=True` if
        # the minimum supported torch version allows it).
        state_dict = torch.load(checkpoint_path, map_location="cpu")

        model.load_state_dict(state_dict)
        model.eval()
        return model


def predict_batch(model, tokenizer, texts, batch_size=32, max_length=1024):
    """Return per-class risk-level probabilities for each text.

    Args:
        model: Model whose forward pass returns a list of per-class logits
            tensors (as ``HarmFormer.forward`` does).
        tokenizer: Tokenizer called on each batch of texts.
        texts: Sliceable sequence of input strings.
        batch_size: Number of texts per forward pass; must be positive.
        max_length: Token length every input is truncated/padded to.

    Returns:
        A nested list indexed as ``[sample][class][risk_level]`` holding
        softmax probabilities rounded to three decimals.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    device = next(model.parameters()).device
    predictions = []

    # Process in batches to avoid OOM.
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start:start + batch_size]

        inputs = tokenizer(
            batch_texts,
            add_special_tokens=True,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            # Stack heads then move the batch axis first:
            # (batch_size, num_classes, num_risk_levels)
            logits = torch.stack(outputs, dim=0).permute(1, 0, 2)
            probs = torch.softmax(logits, dim=-1)

        predictions.extend(
            [[round(prob, 3) for prob in class_probs] for class_probs in sample]
            for sample in probs.cpu().tolist()
        )

    return predictions