# Hugging Face Space (status: Running)
import gradio as gr | |
from transformers import pipeline | |
import sys | |
# Set to True to print the raw pipeline output for every request
DEBUG = False

# Hugging Face Hub id used for both the model weights and its tokenizer
_MODEL_ID = "s-nlp/roberta_toxicity_classifier"

# Build the classifier once at import time; toxicity_classification() reuses
# this module-level instance for every call.
print("Loading toxicity classifier pipeline...")
toxicity_pipeline = pipeline(
    "text-classification",
    model=_MODEL_ID,
    tokenizer=_MODEL_ID,
)
print("Pipeline loaded successfully!")
def toxicity_classification(text: str) -> dict:
    """
    Classify the toxicity of the given text.

    Args:
        text (str): The text to analyze

    Returns:
        dict: A dictionary containing toxicity classification and confidence
    """
    # API/MCP callers can send None or a non-string payload; answer with the
    # same friendly message instead of crashing on `.strip()`.
    if not isinstance(text, str) or not text.strip():
        return {"error": "Please enter some text to analyze"}

    try:
        # Take the top prediction. truncation=True clips inputs longer than
        # the model's maximum sequence length; without it such inputs make
        # the RoBERTa forward pass fail.
        result = toxicity_pipeline(text, truncation=True)[0]
    except Exception as e:
        # UI/MCP boundary: report any model/runtime failure as JSON output.
        return {"error": f"Error processing text: {e}"}

    if DEBUG:
        print(f"DEBUG - Pipeline result: {result}")

    # The model returns labels like "neutral" or "toxic"; anything other than
    # "toxic" is reported as non-toxic.
    label = result.get("label", "neutral").lower()
    # Score of the predicted label, reported as-is for either class.
    score = result.get("score", 0.0)
    classification = "toxic" if label == "toxic" else "non-toxic"

    return {
        "classification": classification,
        "confidence": round(score, 4),
    }
# Example inputs offered under the textbox, one single-field row each
_EXAMPLE_TEXTS = (
    "You are amazing!",
    "This is a wonderful day.",
    "I hate you so much!",
    "You're such an idiot!",
)

# Gradio UI: a free-text box in, the classifier's result dict rendered as JSON
demo = gr.Interface(
    fn=toxicity_classification,
    inputs=gr.Textbox(
        label="Input Text",
        lines=3,
        placeholder="Enter text to analyze for toxicity...",
    ),
    outputs=gr.JSON(label="Toxicity Analysis Results"),
    title="Text Toxicity Classification",
    description="Analyze text toxicity using RoBERTa transformer model (s-nlp/roberta_toxicity_classifier)",
    examples=[[example] for example in _EXAMPLE_TEXTS],
)
if __name__ == "__main__":
    # A first CLI argument of "debug" (any case) runs a local smoke test of
    # toxicity_classification instead of starting the web server.
    cli_args = sys.argv[1:]
    run_local_checks = bool(cli_args) and cli_args[0].lower() == "debug"

    if not run_local_checks:
        # Normal Gradio mode: launch with MCP server enabled
        demo.launch(mcp_server=True)
    else:
        # Module-level rebinding, so toxicity_classification sees it too
        DEBUG = True
        banner = "=" * 50
        print(banner)
        print("DEBUG MODE - Testing toxicity classification locally")
        print(banner)

        samples = (
            "You are amazing!",
            "This is a wonderful day.",
            "I hate you so much!",
            "You're such an idiot!",
            "I disagree with your opinion.",
            "",  # Empty string test
        )
        for case_no, sample in enumerate(samples, start=1):
            print(f"\n--- Test Case {case_no} ---")
            print(f"Input: '{sample}'")
            verdict = toxicity_classification(sample)
            print(f"Output: {verdict}")
            print("-" * 30)

        print("\nDebug testing completed!")