import json
import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModel, PreTrainedModel


class HarmFormer(PreTrainedModel):
    """Shared encoder with one risk-level classification head per harm class."""

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.num_risk_levels = config.num_risk_levels

        # Build the encoder from the config alone; the fine-tuned weights are
        # loaded afterwards in from_pretrained.
        self.base_model = AutoModel.from_config(config)

        hidden_size = self.base_model.config.hidden_size

        # One small MLP head per harm class, each producing logits over the
        # risk levels for that class.
        self.classifiers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, 128),
                nn.ReLU(),
                nn.Linear(128, self.num_risk_levels),
            )
            for _ in range(self.num_classes)
        ])

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # pooled [CLS] representation

        # Run every class head over the same pooled representation.
        logits = [classifier(pooled_output) for classifier in self.classifiers]
        return logits

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Read the custom config, preferring a local directory over the Hub.
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if not os.path.exists(config_path):
            config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.json")
        with open(config_path, 'r') as f:
            model_config = json.load(f)

        # Start from the base encoder's config and attach the task-specific fields.
        base_model_name = model_config.get("model_name", "allenai/longformer-base-4096")
        base_config = AutoConfig.from_pretrained(base_model_name)
        base_config.num_classes = model_config.get("num_classes", 5)
        base_config.num_risk_levels = model_config.get("num_risk_levels", 3)
        base_config.architecture = model_config.get("architecture", "SingleFC")

        model = cls(base_config)

        # Load the fine-tuned weights, again preferring a local checkpoint
        # before falling back to the Hub.
        checkpoint_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        if not os.path.exists(checkpoint_path):
            checkpoint_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="pytorch_model.bin")
        state_dict = torch.load(checkpoint_path, map_location="cpu")

        model.load_state_dict(state_dict)
        model.eval()
        return model


def predict_batch(model, tokenizer, texts, batch_size=32):
    device = next(model.parameters()).device
    predictions = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(
            batch_texts,
            add_special_tokens=True,
            max_length=1024,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            # Stack per-class logits into (batch, num_classes, num_risk_levels)
            # and normalize each class's risk levels with softmax.
            logits = torch.stack(outputs, dim=0).permute(1, 0, 2)
            probs = torch.softmax(logits, dim=-1)
            batch_preds = [
                [[round(prob, 3) for prob in class_probs] for class_probs in sample]
                for sample in probs.cpu().tolist()
            ]
            predictions.extend(batch_preds)

    return predictions
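

# Minimal usage sketch; the checkpoint id below is a placeholder, so point it
# at the directory or Hub repo that actually hosts the HarmFormer weights, and
# load the tokenizer of the base encoder named in its config.json.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = HarmFormer.from_pretrained("path/to/harmformer-checkpoint")  # placeholder id
    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")

    texts = ["First document to score.", "Second document to score."]
    preds = predict_batch(model, tokenizer, texts, batch_size=2)
    # preds[i][c] is the rounded risk-level distribution for class c of text i.
    print(preds[0])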