SamanthaStorm commited on
Commit
cc98c96
·
verified ·
1 Parent(s): 4472a1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -2,17 +2,16 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
5
 
6
- # Load both sentiment and abuse models from the same Hugging Face repo
7
- repo_id = "SamanthaStorm/Tether"
 
8
 
9
- # Load fine-tuned sentiment model
10
- sentiment_model = AutoModelForSequenceClassification.from_pretrained(repo_id, subfolder="sentiment")
11
- sentiment_tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="sentiment")
12
-
13
- # Load abuse detection model
14
- abuse_model = AutoModelForSequenceClassification.from_pretrained(repo_id, subfolder="abuse")
15
- abuse_tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="abuse")
16
 
17
  LABELS = [
18
  "gaslighting", "mockery", "dismissiveness", "control", "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
@@ -82,9 +81,9 @@ def analyze_messages(input_text, risk_flags):
82
 
83
  adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "NEGATIVE" else THRESHOLDS.copy()
84
 
85
- inputs = abuse_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
86
  with torch.no_grad():
87
- outputs = abuse_model(**inputs)
88
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
89
 
90
  pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
@@ -104,14 +103,14 @@ def analyze_messages(input_text, risk_flags):
104
  abuse_level = calculate_abuse_level(scores, THRESHOLDS)
105
  abuse_description = interpret_abuse_level(abuse_level)
106
 
107
- resources = (
108
- "Immediate assistance recommended. Please seek professional help or contact emergency services."
109
- if danger_flag_count >= 2 else
110
- "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
111
- )
112
 
113
  scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores[:15])]
114
  top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
 
115
  top_pattern_explanations = "\n".join([
116
  f"\u2022 {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
117
  for label, _ in top_patterns
 
2
  import torch
3
  import numpy as np
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+ from transformers import RobertaForSequenceClassification, RobertaTokenizer
6
 
7
+ # Load fine-tuned sentiment model from Hugging Face
8
+ sentiment_model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/tether-sentiment")
9
+ sentiment_tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/tether-sentiment")
10
 
11
+ # Load abuse pattern model
12
+ model_name = "SamanthaStorm/Tether"
13
+ model = RobertaForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
14
+ tokenizer = RobertaTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 
 
15
 
16
  LABELS = [
17
  "gaslighting", "mockery", "dismissiveness", "control", "guilt_tripping", "apology_baiting", "blame_shifting", "projection",
 
81
 
82
  adjusted_thresholds = {k: v * 0.8 for k, v in THRESHOLDS.items()} if sentiment_label == "NEGATIVE" else THRESHOLDS.copy()
83
 
84
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
85
  with torch.no_grad():
86
+ outputs = model(**inputs)
87
  scores = torch.sigmoid(outputs.logits.squeeze(0)).numpy()
88
 
89
  pattern_count = sum(score > adjusted_thresholds[label] for label, score in zip(PATTERN_LABELS, scores[:15]))
 
103
  abuse_level = calculate_abuse_level(scores, THRESHOLDS)
104
  abuse_description = interpret_abuse_level(abuse_level)
105
 
106
+ if danger_flag_count >= 2:
107
+ resources = "Immediate assistance recommended. Please seek professional help or contact emergency services."
108
+ else:
109
+ resources = "For more information on abuse patterns, consider reaching out to support groups or professional counselors."
 
110
 
111
  scored_patterns = [(label, score) for label, score in zip(PATTERN_LABELS, scores[:15])]
112
  top_patterns = sorted(scored_patterns, key=lambda x: x[1], reverse=True)[:2]
113
+
114
  top_pattern_explanations = "\n".join([
115
  f"\u2022 {label.replace('_', ' ').title()}: {EXPLANATIONS.get(label, 'No explanation available.')}"
116
  for label, _ in top_patterns