# TetherSST / app.py
import gradio as gr
import torch
from transformers import pipeline as hf_pipeline, AutoModelForSequenceClassification, AutoTokenizer
# ——— 1) Emotion Pipeline ————————————————————————————————————————————————
emotion_pipeline = hf_pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    top_k=None,
    truncation=True
)

def get_emotion_profile(text):
    """
    Returns a dict of { emotion_label: score } for the input text.
    """
    results = emotion_pipeline(text)
    # some pipelines return a nested list ([[…]]); unwrap the inner list
    if isinstance(results, list) and isinstance(results[0], list):
        results = results[0]
    return {r["label"].lower(): round(r["score"], 3) for r in results}
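
# Note: the j-hartmann/emotion-english-distilroberta-base model scores seven
# emotions (anger, disgust, fear, joy, neutral, sadness, surprise), per its
# model card. Illustrative output shape (scores below are made up, not real
# model output):
#   get_emotion_profile("I can't believe you said that")
#   -> {"anger": 0.412, "surprise": 0.301, "neutral": 0.142, ...}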
# ——— 2) Abuse-patterns Model ——————————————————————————————————————————————
model_name = "SamanthaStorm/tether-multilabel-v3"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
LABELS = [
    "blame shifting", "contradictory statements", "control", "dismissiveness",
    "gaslighting", "guilt tripping", "insults", "obscure language",
    "projection", "recovery phase", "threat"
]
THRESHOLDS = {
    "blame shifting": 0.28, "contradictory statements": 0.27, "control": 0.08, "dismissiveness": 0.32,
    "gaslighting": 0.27, "guilt tripping": 0.31, "insults": 0.10, "obscure language": 0.55,
    "projection": 0.09, "recovery phase": 0.33, "threat": 0.15
}
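
# ——— 2b) Motif detection (assumed placeholder) ———————————————————————————————
# analyze_message() below calls detect_motifs(), which is not defined in this
# file. The sketch here is a minimal, assumed stand-in: a small regex lookup
# returning (motif_hits, matched_phrases). The MOTIF_PATTERNS entries are
# illustrative placeholders, not the project's actual motif list.
import re

MOTIF_PATTERNS = {
    "minimization": r"\byou('?re| are) overreacting\b",
    "conditional apology": r"\bsorry (if|that) you\b",
}

def detect_motifs(text):
    """
    Returns (motif_hits, matched_phrases):
      - motif_hits: names of motifs whose pattern matched
      - matched_phrases: the raw phrases that matched
    """
    motif_hits, matched_phrases = [], []
    for name, pattern in MOTIF_PATTERNS.items():
        match = re.search(pattern, text, flags=re.IGNORECASE)
        if match:
            motif_hits.append(name)
            matched_phrases.append(match.group(0))
    return motif_hits, matched_phrases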
# ——— 3) Single-message analysis ——————————————————————————————————————————————
def analyze_message(text):
    """
    Runs motif detection, emotion profiling, and the abuse-pattern classifier.
    Returns a dict with:
      - matched_phrases: list of raw regex hits
      - emotion_profile: { emotion: score }
      - active_patterns: [ labels above their threshold ]
    """
    motif_hits, matched_phrases = detect_motifs(text)
    emotion_profile = get_emotion_profile(text)
    # get raw model scores
    toks = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**toks).logits.squeeze(0)
    scores = torch.sigmoid(logits).cpu().numpy()
    # pick up all labels whose score >= threshold
    active = [lab for lab, sc in zip(LABELS, scores) if sc >= THRESHOLDS[lab]]
    return {
        "matched_phrases": matched_phrases,
        "emotion_profile": emotion_profile,
        "active_patterns": active
    }
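
# Illustrative call (values are hypothetical, not real model output):
#   analyze_message("You're imagining things, I never said that.")
#   -> {"matched_phrases": [], "emotion_profile": {"neutral": 0.51, ...},
#       "active_patterns": ["gaslighting"]}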
# ——— 4) Composite wrapper (for multiple inputs) ———————————————————————————————————
def analyze_composite(*texts):
    """
    Accept multiple text inputs, run analyze_message on each, and
    return a single formatted string.
    """
    outputs = []
    for idx, txt in enumerate(texts, start=1):
        if not txt:
            continue
        r = analyze_message(txt)
        block = (
            f"── Message {idx} ──\n"
            f"Matched Phrases: {r['matched_phrases']}\n"
            f"Emotion Profile: {r['emotion_profile']}\n"
            f"Active Patterns: {r['active_patterns']}\n"
        )
        outputs.append(block)
    return "\n".join(outputs) if outputs else "Please enter at least one message."
# ——— 5) Gradio interface ————————————————————————————————————————————————
# adjust how many message inputs you want here:
message_inputs = [gr.Textbox(label=f"Message {i+1}") for i in range(3)]
iface = gr.Interface(
    fn=analyze_composite,
    inputs=message_inputs,
    outputs=gr.Textbox(label="Analysis"),
    title="Tether Analyzer (no abuse-score / no DARVO)",
    description="Detects motifs, emotions, and active abuse-patterns only."
)

if __name__ == "__main__":
    iface.launch()