Update app.py
app.py CHANGED
@@ -2,17 +2,16 @@ import gradio as gr
 import torch
 import numpy as np
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import RobertaForSequenceClassification, RobertaTokenizer
 
-# Load
-
+# Load fine-tuned sentiment model from Hugging Face
+sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
+sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")
 
-# Load
-
-
-
-# Load abuse detection model
-abuse_model = AutoModelForSequenceClassification.from_pretrained(repo_id, subfolder="abuse")
-abuse_tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="abuse")
+# Load abuse pattern model
+model_name = "SamanthaStorm/Tether"
+model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
+tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 LABELS = [
     "gaslighting", "mockery", "dismissiveness", "control", "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
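Note: the hunk above only loads the new sentiment model; the sentiment_label checked later in analyze_messages is computed elsewhere in app.py. A minimal sketch of one way that label could be derived, assuming the checkpoint's config carries an id2label mapping such as {0: "NEGATIVE", 1: "POSITIVE"} (not confirmed by this diff):

# Hypothetical sketch: deriving sentiment_label from the new sentiment model.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")

def get_sentiment_label(text: str) -> str:
    """Return the highest-scoring sentiment label for a block of text."""
    enc = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = sentiment_model(**enc).logits
    pred_id = int(logits.argmax(dim=-1).item())
    # Assumes the fine-tuned config's id2label includes "NEGATIVE"/"POSITIVE".
    return sentiment_model.config.id2label[pred_id]

# sentiment_label = get_sentiment_label(input_text)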
@@ -82,9 +81,9 @@ def analyze_messages(input_text, risk_flags):
 
     adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "NEGATIVE" else THRESHOLDS.copy()
 
-    inputs =
+    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
-        outputs =
+        outputs = model(**inputs)
     scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
 
     pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
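A self-contained sketch of the scoring step this hunk enables, mirroring the tokenizer and model calls above; PATTERN_LABELS and THRESHOLDS here are stand-ins, since the real definitions live elsewhere in app.py:

# Sketch of the multi-label scoring step with stand-in labels and thresholds.
import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer

model_name = "SamanthaStorm/Tether"
model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)

PATTERN_LABELS = ["gaslighting", "mockery", "dismissiveness"]                      # stand-ins
THRESHOLDS = {"gaslighting": 0.25, "mockery": 0.25, "dismissiveness": 0.30}        # stand-ins

def count_patterns(input_text: str, sentiment_label: str = "NEGATIVE") -> int:
    # Lower every threshold by 20% when overall sentiment is negative,
    # mirroring the adjusted_thresholds expression in the diff.
    thresholds = ({k: v * 0.8 for k, v in THRESHOLDS.items()}
                  if sentiment_label == "NEGATIVE" else THRESHOLDS.copy())
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid (not softmax) because each pattern is an independent binary label.
    scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
    return int(sum(score > thresholds[label]
                   for label, score in zip(PATTERN_LABELS, scores[:len(PATTERN_LABELS)])))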
@@ -104,14 +103,14 @@ def analyze_messages(input_text, risk_flags):
     abuse_level = calculate_abuse_level(scores, THRESHOLDS)
     abuse_description = interpret_abuse_level(abuse_level)
 
-
-        "Immediate assistance recommended. Please seek professional help or contact emergency services."
-
-        "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
-    )
+    if danger_flag_count >= 2:
+        resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
+    else:
+        resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
 
     scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores[:15])]
     top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
+
     top_pattern_explanations = "\n".join([
         f"\u2022 {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
         for label, _ in top_patterns
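A small sketch of the new resources branch together with the top-pattern summary that follows it; danger_flag_count, PATTERN_LABELS, and EXPLANATIONS are stand-ins for values defined elsewhere in app.py:

# Sketch of the resources branch and top-pattern summary with stand-in data.
PATTERN_LABELS = ["gaslighting", "mockery", "dismissiveness"]                       # stand-ins
EXPLANATIONS = {"gaslighting": "Denying someone's reality to destabilize them."}    # stand-in

def summarize(scores, danger_flag_count: int) -> str:
    # Escalate the guidance when two or more danger flags are present.
    if danger_flag_count >= 2:
        resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
    else:
        resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."

    scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores)]
    top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
    top_pattern_explanations = "\n".join(
        f"\u2022 {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
        for label, _ in top_patterns
    )
    return f"{top_pattern_explanations}\n\n{resources}"

# Example:
# print(summarize([0.9, 0.2, 0.6], danger_flag_count=1))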