File size: 1,848 Bytes
fc66fa8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from model_loader import ModelLoader

class TextFilterService:
    """Score a piece of text with the two classifiers held by a loader.

    The loader passed to the constructor is read for exactly five attributes:
    ``hf_tokenizer``, ``hf_model``, ``identity_model``, ``identity_tokenizer``
    and ``device``.
    """

    # Labels of the first model that feed the ``safe`` verdict.
    # NOTE(review): 'threat' is not listed here, so a label with that name
    # would never influence ``safe`` -- confirm this omission is intended.
    _SAFE_CHECK_LABELS = ('toxic', 'severe_toxic', 'obscene', 'insult', 'identity_hate')

    # A score must be strictly below this value to count towards ``safe``.
    _SAFE_CUTOFF = 0.5

    def __init__(self, model_loader: ModelLoader):
        """Store the loader; its models are looked up on every call."""
        self.model_loader = model_loader

    @staticmethod
    def _sigmoid_scores(tokenizer, model, text, device, drop_token_type_ids=False):
        """Tokenize ``text``, run ``model`` without gradients, return sigmoid(logits).

        Args:
            tokenizer: Callable producing a mapping of tensors for ``text``.
            model: Callable taking those tensors as keyword arguments and
                returning an object with a ``logits`` attribute.
            text: The string to score.
            device: Target passed to ``.to()`` for every input tensor.
            drop_token_type_ids: When true, remove ``token_type_ids`` from the
                tokenizer output (if present) before calling the model.

        Returns:
            The element-wise sigmoid of the model's logits.
        """
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        if drop_token_type_ids:
            encoded.pop("token_type_ids", None)
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            output = model(**on_device)
        return torch.sigmoid(output.logits)

    def process_text(self, final_text: str) -> dict:
        """Classify ``final_text`` and return per-label scores plus a verdict.

        Args:
            final_text: The text to score.

        Returns:
            A dict containing, in order:

            * one entry per output of the first model, keyed by the name in
              ``hf_model.config.id2label`` (or ``"Label <i>"`` when index
              ``i`` has no name), holding the sigmoid score as a float;
            * ``"identity_hate_custom"``: score at index 1 of the identity
              classifier's first output row;
            * ``"not_identity_hate_custom"``: score at index 0 of that row;
            * ``"safe"``: ``True`` only when every label in
              ``_SAFE_CHECK_LABELS`` (missing labels count as 0) and the
              custom identity score are all below ``_SAFE_CUTOFF``.
        """
        loader = self.model_loader
        toxic_tokenizer = loader.hf_tokenizer
        toxic_model = loader.hf_model
        identity_model = loader.identity_model
        identity_tokenizer = loader.identity_tokenizer
        device = loader.device
        cutoff = self._SAFE_CUTOFF

        # Toxic-BERT inference: only the first row of the output is used.
        toxic_row = self._sigmoid_scores(
            toxic_tokenizer, toxic_model, final_text, device
        )[0]
        label_names = toxic_model.config.id2label
        results = {
            label_names.get(i, f"Label {i}"): float(score)
            for i, score in enumerate(toxic_row)
        }

        # Identity hate classifier: this model is called without token_type_ids.
        # NOTE(review): sigmoid is applied to each logit independently, so the
        # two scores below need not sum to 1 -- confirm the head is meant to be
        # read this way rather than through a softmax.
        identity_row = self._sigmoid_scores(
            identity_tokenizer,
            identity_model,
            final_text,
            device,
            drop_token_type_ids=True,
        )[0]
        identity_prob = identity_row[1].item()
        not_identity_prob = identity_row[0].item()
        results.update(
            {
                "identity_hate_custom": identity_prob,
                "not_identity_hate_custom": not_identity_prob,
            }
        )

        # Verdict: every watched label and the custom identity score must be
        # under the cutoff; a label the first model does not report counts as 0.
        watched_labels_clear = all(
            results.get(label, 0) < cutoff for label in self._SAFE_CHECK_LABELS
        )
        results["safe"] = watched_labels_clear and identity_prob < cutoff

        return results