import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import numpy as np
import json
from datetime import datetime
import logging
import os


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FixedMultiAgentSystem:

    def __init__(self):
        self.detection_agent = None
        self.counter_speech_agent = None
        self.moderation_agent = None
        self.sentiment_agent = None

        self.counter_speech_prompts = self.load_prompts("counter_speech_prompts.json")
        self.moderation_prompts = self.load_prompts("moderation_prompts.json")

        self.initialize_agents()

    def load_prompts(self, filename):
        """Load prompts from JSON file with robust fallback"""
        try:
            if os.path.exists(filename):
                with open(filename, 'r', encoding='utf-8') as f:
                    return json.load(f)
            else:
                logger.warning(f"Prompt file {filename} not found, using built-in prompts")
                return self.get_default_prompts(filename)
        except Exception as e:
            logger.error(f"Error loading prompts from {filename}: {e}")
            return self.get_default_prompts(filename)

    def get_default_prompts(self, filename):
        """Comprehensive default prompts as fallback"""
        if "counter_speech" in filename:
            return {
                "counter_speech_prompts": {
                    "high_risk": {
                        "system_prompt": "You are an expert educator specializing in counter-speech and conflict de-escalation.",
                        "user_prompt_template": "Generate a respectful, educational counter-speech response to address harmful content while promoting understanding. Original text (Risk: {risk_level}, Confidence: {confidence}%, Sentiment: {sentiment}): \"{original_text}\"\n\nProvide a constructive response that educates without attacking:",
                        "fallback_responses": [
                            "This type of language can cause real harm to individuals and communities. Consider expressing your concerns in a way that respects everyone's dignity and opens constructive dialogue.",
                            "Instead of divisive language, try focusing on shared values and common ground. Everyone deserves respect regardless of their background.",
                            "Strong communities are built on mutual respect and understanding. How can we work together rather than against each other?"
                        ]
                    },
                    "medium_risk": {
                        "fallback_responses": [
                            "This message might be interpreted as harmful by some. Consider rephrasing to express your thoughts more constructively.",
                            "Try framing your message to invite discussion rather than potentially excluding others.",
                            "How might you express this sentiment in a way that brings people together rather than apart?"
                        ]
                    },
                    "low_risk": {
                        "fallback_responses": [
                            "While this seems mostly positive, consider how your words might be received by everyone in the conversation.",
                            "Every interaction is a chance to build understanding and connection.",
                            "Consider how you can use your voice to create an even more welcoming environment."
                        ]
                    },
                    "general_template": {
                        "fallback_responses": [
                            "Thank you for sharing your thoughts. Building strong communities works best when we focus on shared values and constructive dialogue.",
                            "I appreciate your perspective. Sometimes our strongest feelings can be expressed in ways that bring people together.",
                            "Your engagement with this topic is clear. When we channel that energy into inclusive dialogue, we often find solutions that work for everyone."
                        ]
                    }
                }
            }
        else:
            return {
                "moderation_prompts": {
                    "comprehensive_analysis": {
                        "system_prompt": "You are an expert content moderation specialist analyzing text for safety and compliance.",
                        "user_prompt_template": "Analyze this text for potential violations: \"{text}\"\n\nProvide brief analysis: 1) Safety level 2) Main concerns 3) Recommended action\n\nAnalysis:"
                    }
                }
            }

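    # For reference, a minimal sketch of what an on-disk "counter_speech_prompts.json"
    # could look like. This mirrors the built-in defaults above; the exact file
    # contents are up to the deployment:
    #
    # {
    #   "counter_speech_prompts": {
    #     "high_risk":   {"fallback_responses": ["..."]},
    #     "medium_risk": {"fallback_responses": ["..."]},
    #     "low_risk":    {"fallback_responses": ["..."]}
    #   }
    # }
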
    def initialize_agents(self):
        """Initialize all AI agents with proper error handling"""
        logger.info("🤖 Initializing Fixed Multi-Agent System...")

        self.setup_detection_agent()
        self.setup_lightweight_agents()

        logger.info("✅ All agents initialized successfully!")

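    # NOTE (assumption): "./model" is expected to be a local directory produced by
    # save_pretrained() for both the fine-tuned classifier and its tokenizer
    # (config.json, tokenizer files, model weights). If loading it fails for any
    # reason, setup_detection_agent() falls back to a public toxicity checkpoint.
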
    def setup_detection_agent(self):
        """Initialize the hate speech detection agent with proper label handling"""
        try:
            logger.info("🔍 Loading Detection Agent (Fine-tuned DistilBERT)...")
            model_path = "./model"

            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=torch.float32
            )

            self.detection_agent = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
                return_all_scores=True,
                device=0 if torch.cuda.is_available() else -1
            )

            # Probe the model once so we know which raw label means "hate".
            self.test_model_labels()

            logger.info("✅ Detection Agent loaded successfully")

        except Exception as e:
            logger.error(f"❌ Detection Agent failed: {e}")
            logger.info("🔄 Using fallback detection model...")
            self.detection_agent = pipeline(
                "text-classification",
                model="unitary/toxic-bert",
                return_all_scores=True
            )
            # The fallback checkpoint reports its own label names (e.g. "toxic")
            # rather than LABEL_0/LABEL_1, so map its toxicity score to "hate" and
            # treat everything else as normal. Adjust if the fallback model changes.
            self.model_label_mapping = {"toxic": "hate"}
            self.hate_label = "toxic"
            self.normal_label = "non-toxic"  # placeholder: this model has no explicit "normal" label

    def test_model_labels(self):
        """Test model to understand its label mapping"""
        try:
            # Run a clearly benign probe sentence through the classifier.
            safe_text = "I love sunny days and happy people."
            results = self.detection_agent(safe_text)

            if isinstance(results, list) and len(results) > 0:
                if isinstance(results[0], list):
                    results = results[0]

            # Whichever label wins on benign text is assumed to be the "normal" class.
            max_result = max(results, key=lambda x: x['score'])
            safe_label = max_result['label']

            if safe_label in ['LABEL_0', '0']:
                self.model_label_mapping = {"LABEL_0": "normal", "LABEL_1": "hate"}
                self.hate_label = "LABEL_1"
                self.normal_label = "LABEL_0"
            elif safe_label in ['LABEL_1', '1']:
                self.model_label_mapping = {"LABEL_0": "hate", "LABEL_1": "normal"}
                self.hate_label = "LABEL_0"
                self.normal_label = "LABEL_1"
            else:
                # Unknown label scheme: keep whatever names the model reports.
                self.model_label_mapping = {safe_label: "normal"}
                self.normal_label = safe_label

                other_labels = [r['label'] for r in results if r['label'] != safe_label]
                if other_labels:
                    self.hate_label = other_labels[0]
                    self.model_label_mapping[self.hate_label] = "hate"

            logger.info(f"Model label mapping determined: {self.model_label_mapping}")
            logger.info(f"Normal label: {self.normal_label}, Hate label: {self.hate_label}")

        except Exception as e:
            logger.error(f"Error testing model labels: {e}")
            # Fall back to the most common convention for binary classifiers.
            self.model_label_mapping = {"LABEL_0": "normal", "LABEL_1": "hate"}
            self.hate_label = "LABEL_1"
            self.normal_label = "LABEL_0"

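    # Illustrative outcome (assuming a binary fine-tuned checkpoint that kept the
    # default id2label names): benign text scoring highest on LABEL_0 yields
    # model_label_mapping == {"LABEL_0": "normal", "LABEL_1": "hate"}.
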
    def setup_lightweight_agents(self):
        """Setup only essential additional agents to reduce load time"""
        try:
            logger.info("📊 Loading Lightweight Sentiment Agent...")
            self.sentiment_agent = pipeline(
                "sentiment-analysis",
                model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                return_all_scores=True,
                device=0 if torch.cuda.is_available() else -1
            )
            logger.info("✅ Sentiment Agent loaded")

            logger.info("💬 Using template-based counter-speech (fast mode)")
            self.counter_speech_agent = None
            self.moderation_agent = None

        except Exception as e:
            logger.error(f"❌ Lightweight agents failed: {e}")
            self.sentiment_agent = None

    def detect_hate_speech(self, text):
        """Fixed detection with proper label interpretation"""
        if not text or not text.strip():
            return {
                "status": "❌ Please enter some text to analyze.",
                "prediction": "No input",
                "confidence": 0.0,
                "all_scores": {},
                "risk_level": "Unknown",
                "is_hate_speech": False
            }

        try:
            results = self.detection_agent(text.strip())

            if isinstance(results, list) and len(results) > 0:
                if isinstance(results[0], list):
                    results = results[0]

            all_scores = {}
            hate_score = 0
            normal_score = 0

            # Collect the score for every raw label and map it to "hate"/"normal".
            for result in results:
                label = result["label"]
                score = result["score"]

                mapped_label = self.model_label_mapping.get(label, label)
                all_scores[f"{label} ({mapped_label})"] = {
                    "score": score,
                    "percentage": f"{score*100:.2f}%",
                    "confidence": f"{score:.4f}"
                }

                if label == getattr(self, 'hate_label', 'LABEL_1'):
                    hate_score = score
                elif label == getattr(self, 'normal_label', 'LABEL_0'):
                    normal_score = score

            # Default verdict: normal content, low risk.
            is_hate_speech = False
            risk_level = "Low"
            predicted_label = "Normal"
            confidence = normal_score

            if hate_score > normal_score:
                confidence = hate_score
                predicted_label = "Hate Speech"

                if hate_score > 0.8:
                    is_hate_speech = True
                    risk_level = "High"
                    status = f"🚨 High confidence hate speech detected! (Hate: {hate_score:.2%})"
                elif hate_score > 0.6:
                    is_hate_speech = True
                    risk_level = "Medium"
                    status = f"⚠️ Potential hate speech detected (Hate: {hate_score:.2%})"
                else:
                    risk_level = "Low-Medium"
                    status = f"⚡ Low confidence hate detection (Hate: {hate_score:.2%})"
            else:
                risk_level = "Low"
                status = f"✅ No hate speech detected (Normal: {normal_score:.2%})"

            return {
                "status": status,
                "prediction": predicted_label,
                "confidence": confidence,
                "all_scores": all_scores,
                "risk_level": risk_level,
                "is_hate_speech": is_hate_speech,
                "hate_score": hate_score,
                "normal_score": normal_score
            }

        except Exception as e:
            logger.error(f"Detection error: {e}")
            return {
                "status": f"❌ Detection error: {str(e)}",
                "prediction": "Error",
                "confidence": 0.0,
                "all_scores": {},
                "risk_level": "Unknown",
                "is_hate_speech": False
            }

    def analyze_sentiment(self, text):
        """Fast sentiment analysis"""
        if not self.sentiment_agent or not text.strip():
            return {"sentiment": "neutral", "confidence": 0.0, "all_sentiments": {}}

        try:
            results = self.sentiment_agent(text.strip())
            if isinstance(results, list) and len(results) > 0:
                if isinstance(results[0], list):
                    results = results[0]

            best_sentiment = max(results, key=lambda x: x['score'])
            return {
                "sentiment": best_sentiment['label'].lower(),
                "confidence": best_sentiment['score'],
                "all_sentiments": {r['label']: r['score'] for r in results}
            }
        except Exception as e:
            logger.error(f"Sentiment analysis error: {e}")

        return {"sentiment": "neutral", "confidence": 0.0, "all_sentiments": {}}

    def generate_template_moderation(self, text, detection_result, sentiment_result):
        """Fast template-based moderation analysis"""
        risk_level = detection_result.get("risk_level", "Low").lower()
        confidence = detection_result.get("confidence", 0.0)
        hate_score = detection_result.get("hate_score", 0.0)

        if hate_score > 0.8:
            analysis = f"🚨 HIGH RISK: Clear hate speech detected with {confidence:.1%} confidence. Immediate review recommended. Content may violate community standards and could cause harm."
            safety_level = "harmful"
        elif hate_score > 0.6:
            analysis = f"⚠️ MEDIUM RISK: Potentially problematic content detected with {confidence:.1%} confidence. Human review recommended to assess context and intent."
            safety_level = "concerning"
        elif hate_score > 0.3:
            analysis = f"⚡ LOW RISK: Minor concerns detected with {confidence:.1%} confidence. Content appears mostly acceptable but may benefit from user education."
            safety_level = "review_needed"
        else:
            analysis = f"✅ SAFE: No significant violations detected. Content appears to meet community standards with {confidence:.1%} confidence."
            safety_level = "safe"

        return {
            "analysis": analysis,
            "confidence": confidence,
            "safety_level": safety_level,
            "method": "template_based_fast"
        }

    def generate_template_counter_speech(self, text, detection_result, sentiment_result):
        """Fast template-based counter-speech"""
        if not detection_result.get("is_hate_speech", False):
            return "✨ This text promotes positive communication. Great job maintaining respectful dialogue!"

        risk_level = detection_result.get("risk_level", "Low").lower()

        counter_config = self.counter_speech_prompts.get("counter_speech_prompts", {})

        if risk_level == "high":
            responses = counter_config.get("high_risk", {}).get("fallback_responses", [
                "This type of language can cause real harm. Consider expressing concerns in a way that respects everyone's dignity."
            ])
        elif risk_level == "medium":
            responses = counter_config.get("medium_risk", {}).get("fallback_responses", [
                "This message might be harmful to some. Consider rephrasing to express thoughts more constructively."
            ])
        else:
            responses = counter_config.get("low_risk", {}).get("fallback_responses", [
                "Consider how your words might be received by everyone in the conversation."
            ])

        import random
        return f"📝 **Educational Response** ({risk_level.title()} Risk): {random.choice(responses)}"

    def comprehensive_analysis(self, text):
        """Fast comprehensive analysis with fixed logic"""
        start_time = datetime.now()

        detection_result = self.detect_hate_speech(text)
        sentiment_result = self.analyze_sentiment(text)

        moderation_result = self.generate_template_moderation(text, detection_result, sentiment_result)
        counter_speech = self.generate_template_counter_speech(text, detection_result, sentiment_result)

        processing_time = (datetime.now() - start_time).total_seconds()

        return {
            "detection": detection_result,
            "sentiment": sentiment_result,
            "moderation": moderation_result,
            "counter_speech": counter_speech,
            "processing_time": processing_time,
            "timestamp": datetime.now().isoformat()
        }

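# Illustrative usage outside the Gradio UI (uses only the class defined above;
# nothing new is assumed):
#
#     system = FixedMultiAgentSystem()
#     report = system.comprehensive_analysis("Thank you for sharing your perspective.")
#     print(report["detection"]["status"])
#     print(report["counter_speech"])
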
logger.info("🚀 Starting Fixed Multi-Agent System...")
agent_system = FixedMultiAgentSystem()

def analyze_text_fixed(text):
    """Fixed analysis function with proper logic"""
    if not text or not text.strip():
        return (
            "❌ Please enter some text to analyze.",
            "No analysis performed.",
            "No input provided",
            {}
        )

    results = agent_system.comprehensive_analysis(text)

    detection_status = results["detection"]["status"]
    detection_scores = results["detection"]["all_scores"]
    counter_speech = results["counter_speech"]

    agent_summary = f"""
🔍 **Detection Agent**: {results['detection']['risk_level']} risk ({results['detection']['confidence']:.2%} confidence)
   ↳ Hate Score: {results['detection'].get('hate_score', 0):.2%} | Normal Score: {results['detection'].get('normal_score', 0):.2%}
📊 **Sentiment Agent**: {results['sentiment']['sentiment'].title()} ({results['sentiment']['confidence']:.2%} confidence)
🛡️ **Moderation Agent**: {results['moderation']['safety_level'].title()} ({results['moderation']['method']})
💬 **Counter-Speech Agent**: Template-based response system
⚡ **Processing Time**: {results['processing_time']:.2f} seconds (Fixed & Optimized)

📋 **Quick Analysis**: {results['moderation']['analysis'][:150]}...
"""

    all_agent_data = {
        "Detection_Analysis": {
            "corrected_scores": detection_scores,
            "hate_score": results['detection'].get('hate_score', 0),
            "normal_score": results['detection'].get('normal_score', 0),
            "final_prediction": results['detection']['prediction'],
            "risk_level": results['detection']['risk_level'],
            "is_hate_speech": results['detection']['is_hate_speech']
        },
        "Sentiment_Analysis": {
            "primary_sentiment": results['sentiment']['sentiment'],
            "all_sentiments": results['sentiment'].get('all_sentiments', {})
        },
        "Moderation_Analysis": {
            "safety_level": results['moderation']['safety_level'],
            "analysis": results['moderation']['analysis'],
            "method": results['moderation']['method']
        },
        "System_Info": {
            "mode": "Fixed & Optimized",
            "processing_time_seconds": results['processing_time'],
            "timestamp": results['timestamp'],
            "model_labels": getattr(agent_system, 'model_label_mapping', {})
        }
    }

    # Return exactly one value per output component wired up below:
    # detection_output, counter_speech_output, agent_summary, all_agents_output.
    return detection_status, counter_speech, agent_summary, all_agent_data

with gr.Blocks(
    title="Fixed Multi-Agent Hate Speech Detection",
    theme=gr.themes.Soft()
) as demo:

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Enter text for fixed multi-agent analysis",
                placeholder="Test the fixed system with any text...",
                lines=4
            )

            with gr.Row():
                analyze_btn = gr.Button("🔧 Run Fixed Analysis", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

            gr.Examples(
                examples=[
                    ["The diversity in our group makes our discussions much richer and more meaningful."],
                    ["I love collaborating with people from different backgrounds."],
                    ["This is a wonderful day to learn something new!"],
                    ["Thank you for sharing your perspective with us."],
                    ["Let's work together to build something amazing."]
                ],
                inputs=text_input,
                label="📝 Test with these examples (should show as SAFE):"
            )

    with gr.Row():
        with gr.Column():
            detection_output = gr.Textbox(
                label="🎯 Fixed Detection Result",
                interactive=False,
                lines=3
            )

            agent_summary = gr.Textbox(
                label="🔧 Fixed Agent Summary",
                interactive=False,
                lines=8
            )

        with gr.Column():
            counter_speech_output = gr.Textbox(
                label="💬 Counter-Speech Response",
                interactive=False,
                lines=4
            )

    with gr.Row():
        all_agents_output = gr.JSON(
            label="📊 Complete Fixed Analysis Data",
            visible=True
        )

    # Wire one output component per value returned by analyze_text_fixed
    # (the component list previously repeated all_agents_output).
    analyze_btn.click(
        fn=analyze_text_fixed,
        inputs=text_input,
        outputs=[detection_output, counter_speech_output, agent_summary, all_agents_output]
    )

    clear_btn.click(
        fn=lambda: ("", "", "", "", {}),
        outputs=[text_input, detection_output, counter_speech_output, agent_summary, all_agents_output]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
        share=False
    )
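
# Running locally is expected to look roughly like this (assuming the gradio,
# transformers, and torch dependencies are installed and this file is saved as,
# say, app.py; the filename is an assumption):
#
#     python app.py
#     # then open http://localhost:7860 in a browser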