SamanthaStorm commited on
Commit
f4d2a9a
·
verified ·
1 Parent(s): 834f0ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -6
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import re
3
  from transformers import pipeline as hf_pipeline
4
 
5
- # Load models
6
  sst_classifier = hf_pipeline(
7
  "text-classification",
8
  model="SamanthaStorm/tether-sst",
@@ -10,6 +10,7 @@ sst_classifier = hf_pipeline(
10
  truncation=True
11
  )
12
 
 
13
  emotion_pipeline = hf_pipeline(
14
  "text-classification",
15
  model="j-hartmann/emotion-english-distilroberta-base",
@@ -17,14 +18,30 @@ emotion_pipeline = hf_pipeline(
17
  truncation=True
18
  )
19
 
20
- # Functions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def get_emotion_profile(text):
22
  emotions = emotion_pipeline(text)
23
  if isinstance(emotions, list) and isinstance(emotions[0], list):
24
  emotions = emotions[0]
25
  return {e['label'].lower(): round(e['score'], 3) for e in emotions}
26
 
27
- def get_emotional_tone_tag(emotions, sentiment, patterns, abuse_score):
 
28
  sadness = emotions.get("sadness", 0)
29
  joy = emotions.get("joy", 0)
30
  neutral = emotions.get("neutral", 0)
@@ -77,17 +94,22 @@ def get_emotional_tone_tag(emotions, sentiment, patterns, abuse_score):
77
 
78
  return None
79
 
80
- # Main interface function
81
  def analyze_message(text):
 
82
  sst_result = sst_classifier(text)[0]
83
  sentiment_label = "supportive" if sst_result["label"] == "LABEL_0" else "undermining"
84
  sentiment_score = round(sst_result["score"] * 100, 2)
85
 
 
86
  emotions = get_emotion_profile(text)
87
  emotion_summary = "\n".join([f"{k.title()}: {v:.2f}" for k, v in emotions.items()])
88
 
 
 
89
 
90
- tone_tag = get_emotional_tone_tag(emotions, sentiment_label, patterns, abuse_score)
 
91
  tone_output = tone_tag if tone_tag else "None detected"
92
 
93
  return (
@@ -96,7 +118,7 @@ def analyze_message(text):
96
  f"🔍 Tone Tag: {tone_output}"
97
  )
98
 
99
- # Gradio app
100
  iface = gr.Interface(
101
  fn=analyze_message,
102
  inputs=gr.Textbox(lines=4, placeholder="Paste a message here..."),
 
2
  import re
3
  from transformers import pipeline as hf_pipeline
4
 
5
+ # Load SST sentiment classifier
6
  sst_classifier = hf_pipeline(
7
  "text-classification",
8
  model="SamanthaStorm/tether-sst",
 
10
  truncation=True
11
  )
12
 
13
+ # Load emotion classifier
14
  emotion_pipeline = hf_pipeline(
15
  "text-classification",
16
  model="j-hartmann/emotion-english-distilroberta-base",
 
18
  truncation=True
19
  )
20
 
21
+ # Load Tether's abuse pattern classifier
22
+ abuse_classifier = hf_pipeline(
23
+ "text-classification",
24
+ model="SamanthaStorm/tether-multilabel-v2",
25
+ top_k=None,
26
+ truncation=True
27
+ )
28
+
29
+ # Extract abuse pattern labels
30
+ def get_abuse_patterns(text, threshold=0.5):
31
+ results = abuse_classifier(text)
32
+ if isinstance(results, list) and isinstance(results[0], list):
33
+ results = results[0]
34
+ return [r['label'] for r in results if r['score'] >= threshold]
35
+
36
+ # Emotion extraction
37
  def get_emotion_profile(text):
38
  emotions = emotion_pipeline(text)
39
  if isinstance(emotions, list) and isinstance(emotions[0], list):
40
  emotions = emotions[0]
41
  return {e['label'].lower(): round(e['score'], 3) for e in emotions}
42
 
43
+ # Emotional tone tagging based on behavior + emotion + sentiment
44
+ def get_emotional_tone_tag(emotions, sentiment, patterns, abuse_score=0):
45
  sadness = emotions.get("sadness", 0)
46
  joy = emotions.get("joy", 0)
47
  neutral = emotions.get("neutral", 0)
 
94
 
95
  return None
96
 
97
+ # Analysis logic
98
  def analyze_message(text):
99
+ # Sentiment
100
  sst_result = sst_classifier(text)[0]
101
  sentiment_label = "supportive" if sst_result["label"] == "LABEL_0" else "undermining"
102
  sentiment_score = round(sst_result["score"] * 100, 2)
103
 
104
+ # Emotions
105
  emotions = get_emotion_profile(text)
106
  emotion_summary = "\n".join([f"{k.title()}: {v:.2f}" for k, v in emotions.items()])
107
 
108
+ # Abuse patterns (used internally)
109
+ patterns = get_abuse_patterns(text)
110
 
111
+ # Tone tag
112
+ tone_tag = get_emotional_tone_tag(emotions, sentiment_label, patterns)
113
  tone_output = tone_tag if tone_tag else "None detected"
114
 
115
  return (
 
118
  f"🔍 Tone Tag: {tone_output}"
119
  )
120
 
121
+ # Gradio UI
122
  iface = gr.Interface(
123
  fn=analyze_message,
124
  inputs=gr.Textbox(lines=4, placeholder="Paste a message here..."),