"""
simple_demo.py
"""
import os
import gradio as gr
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gspread
from google.oauth2 import service_account
import json
from datetime import datetime
import uuid
# Local imports
from lionguard2 import LionGuard2, CATEGORIES
from utils import get_embeddings
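# Note: get_embeddings() is expected to take a list of strings and return a batch of
# embeddings, and LionGuard2.predict() is expected to return a dict mapping each
# category key in CATEGORIES (plus "binary") to a list of per-text scores; this is
# how the results are consumed in analyze_text() below.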
# Google Sheets configuration
GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
RESULTS_SHEET_NAME = "results"
VOTES_SHEET_NAME = "votes"
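# Expected environment variables (based on how they are used below):
#   GOOGLE_SHEET_URL     - URL of the Google Sheet used for logging
#   GCP_SERVICE_ACCOUNT  - service-account credentials as a JSON string
#   HF_API_KEY           - Hugging Face token for downloading the model
# Rows are appended to the "results" and "votes" worksheets positionally, so the
# worksheet columns are assumed to follow the same order as the dicts built in
# analyze_text() (datetime, text_id, text, binary_score, per-category scores) and
# save_vote_data() (datetime, text_id, agree).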
# Helper to save results data
def save_results_data(row):
try:
# Create credentials object
credentials = service_account.Credentials.from_service_account_info(
json.loads(GOOGLE_CREDENTIALS),
scopes=[
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/drive",
],
)
# Create authorized client
gc = gspread.authorize(credentials)
sheet = gc.open_by_url(GOOGLE_SHEET_URL)
ws = sheet.worksheet(RESULTS_SHEET_NAME)
ws.append_row(list(row.values()))
print(f"Saved results data for text_id: {row['text_id']}")
except Exception as e:
print(f"Error saving results data: {e}")
# Helper to save vote data
def save_vote_data(text_id, agree):
try:
# Create credentials object
credentials = service_account.Credentials.from_service_account_info(
json.loads(GOOGLE_CREDENTIALS),
scopes=[
"https://www.googleapis.com/auth/spreadsheets",
"https://www.googleapis.com/auth/drive",
],
)
# Create authorized client
gc = gspread.authorize(credentials)
sheet = gc.open_by_url(GOOGLE_SHEET_URL)
ws = sheet.worksheet(VOTES_SHEET_NAME)
vote_row = {
"datetime": datetime.now().isoformat(),
"text_id": text_id,
"agree": agree
}
ws.append_row(list(vote_row.values()))
print(f"Saved vote data for text_id: {text_id}, agree: {agree}")
except Exception as e:
print(f"Error saving vote data: {e}")
def download_model(repo_id, filename="LionGuard2.safetensors", token=None):
"""
Download the LionGuard2 model from a Hugging Face private repository.
Args:
repo_id: The Hugging Face repository ID (e.g., "username/repo-name")
filename: The filename to download (default: "LionGuard2.safetensors")
token: Hugging Face access token for private repositories
Returns:
Path to the downloaded file
"""
if token is None:
token = os.environ.get("HF_API_KEY")
# Download the model file
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
token=token,
cache_dir="./cache"
)
return model_path
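# Example usage (hypothetical repo ID; requires a valid token via HF_API_KEY or the
# `token` argument):
#   model_path = download_model("your-org/lionguard2-private")
#   state_dict = load_file(model_path)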
def load_model(repo_id=None, use_local=True):
"""
Load the LionGuard2 model from either local file or Hugging Face repository.
Args:
repo_id: The Hugging Face repository ID (optional)
use_local: Whether to use local file first (default: True)
"""
model = LionGuard2()
model.eval()
model_path = "LionGuard2.safetensors"
# Try to download from HF repo if specified and local file doesn't exist or use_local is False
if repo_id and (not use_local or not os.path.exists(model_path)):
try:
print(f"Downloading LionGuard2.safetensors from {repo_id}...")
model_path = download_model(repo_id)
print(f"Model downloaded to: {model_path}")
except Exception as e:
print(f"Failed to download from HF repo: {e}")
if not os.path.exists("LionGuard2.safetensors"):
raise Exception("No local model file found and failed to download from HF repo")
print("Falling back to local file...")
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
return model
# Load model once at startup
HF_REPO_ID = "gabrielchua/refactored-guacamole" # Update this with the actual repo ID
model = load_model(repo_id=HF_REPO_ID)
def format_score_with_style(score_str, compact=False):
    """
    Format score with color and emoji based on value.

    Args:
        score_str: Score as string (could be "-" for missing values)
        compact: Whether to use compact styling

    Returns:
        HTML formatted string with styling
    """
    if score_str == "-":
        return '-'
    try:
        score = float(score_str)
        percentage = int(score * 100)
        padding = "6px 12px" if compact else "8px 16px"
        # Representative dark-mode badge styling
        badge = f"display: inline-block; padding: {padding}; border-radius: 9999px; font-weight: 600;"
        if score < 0.4:
            # Safe - refined green
            return f'<span style="{badge} background-color: #064e3b; color: #6ee7b7;">💚 {percentage}%</span>'
        elif 0.4 <= score < 0.7:
            # Warning - refined amber
            return f'<span style="{badge} background-color: #78350f; color: #fcd34d;">⚠️ {percentage}%</span>'
        else:
            # High risk - refined red
            return f'<span style="{badge} background-color: #7f1d1d; color: #fca5a5;">🚨 {percentage}%</span>'
    except (ValueError, TypeError):
        return score_str
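# For example, format_score_with_style("0.12") renders a green 💚 12% badge, while
# format_score_with_style("0.85") renders a red 🚨 85% badge.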
def format_binary_score(score):
    """Format the binary score with appropriate styling for dark mode."""
    percentage = int(score * 100)
    # Representative dark-mode banner styling
    banner = "text-align: center; padding: 16px; border-radius: 12px; font-size: 1.1em; font-weight: 600;"
    if score < 0.4:
        return f'<div style="{banner} background-color: #064e3b; color: #6ee7b7;">✅ Pass ({percentage}/100)</div>'
    elif 0.4 <= score < 0.7:
        return f'<div style="{banner} background-color: #78350f; color: #fcd34d;">⚠️ Warning ({percentage}/100)</div>'
    else:
        return f'<div style="{banner} background-color: #7f1d1d; color: #fca5a5;">🚨 Fail ({percentage}/100)</div>'
def analyze_text(text):
"""
Analyze text for content moderation violations.
Args:
text: Input text to analyze
Returns:
binary_score: Overall safety score with styling
category_table: HTML table with category-specific scores and styling
text_id: Unique identifier for this analysis
voting_section: HTML for voting buttons
"""
    if not text.strip():
        empty_html = '<div style="text-align: center; color: #9ca3af; padding: 16px;">Enter text to analyze</div>'
        return empty_html, empty_html, "", ""
try:
# Generate unique text ID
text_id = str(uuid.uuid4())
# Get embeddings for the text
embeddings = get_embeddings([text])
# Run inference
results = model.predict(embeddings)
# Extract binary score (overall safety)
binary_score = results.get('binary', [0.0])[0]
# Extract specific scores for Google Sheets
hateful_scores = CATEGORIES['hateful']
hateful_l1_score = results.get(hateful_scores[0], [0.0])[0] if len(hateful_scores) > 0 else 0.0
hateful_l2_score = results.get(hateful_scores[1], [0.0])[0] if len(hateful_scores) > 1 else 0.0
insults_scores = CATEGORIES['insults']
insults_score = results.get(insults_scores[0], [0.0])[0] if len(insults_scores) > 0 else 0.0
sexual_scores = CATEGORIES['sexual']
sexual_l1_score = results.get(sexual_scores[0], [0.0])[0] if len(sexual_scores) > 0 else 0.0
sexual_l2_score = results.get(sexual_scores[1], [0.0])[0] if len(sexual_scores) > 1 else 0.0
physical_violence_scores = CATEGORIES['physical_violence']
physical_violence_score = results.get(physical_violence_scores[0], [0.0])[0] if len(physical_violence_scores) > 0 else 0.0
self_harm_scores = CATEGORIES['self_harm']
self_harm_l1_score = results.get(self_harm_scores[0], [0.0])[0] if len(self_harm_scores) > 0 else 0.0
self_harm_l2_score = results.get(self_harm_scores[1], [0.0])[0] if len(self_harm_scores) > 1 else 0.0
aom_scores = CATEGORIES['all_other_misconduct']
aom_l1_score = results.get(aom_scores[0], [0.0])[0] if len(aom_scores) > 0 else 0.0
aom_l2_score = results.get(aom_scores[1], [0.0])[0] if len(aom_scores) > 1 else 0.0
# Save results to Google Sheets
if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
results_row = {
"datetime": datetime.now().isoformat(),
"text_id": text_id,
"text": text,
"binary_score": binary_score,
"hateful_l1_score": hateful_l1_score,
"hateful_l2_score": hateful_l2_score,
"insults_score": insults_score,
"sexual_l1_score": sexual_l1_score,
"sexual_l2_score": sexual_l2_score,
"physical_violence_score": physical_violence_score,
"self_harm_l1_score": self_harm_l1_score,
"self_harm_l2_score": self_harm_l2_score,
"aom_l1_score": aom_l1_score,
"aom_l2_score": aom_l2_score
}
save_results_data(results_row)
# Prepare category data with max scores and dropdowns
categories_html = []
# Define the main categories (excluding binary)
main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
for category in main_categories:
subcategories = CATEGORIES[category]
category_name = category.replace('_', ' ').title()
# Add emoji to category name based on type
            category_emojis = {
                'Hateful': '🤬',
                'Insults': '💢',
                'Sexual': '🔞',
                'Physical Violence': '⚔️',
                'Self Harm': '☹️',
                'All Other Misconduct': '🙅‍♂️'
            }
            category_display = f"{category_emojis.get(category_name, '📊')} {category_name}"
# Get scores for all levels
level_scores = []
for i, subcategory_key in enumerate(subcategories):
score = results.get(subcategory_key, [0.0])[0]
level_scores.append((f"Level {i+1}", score))
# Find max score
max_score = max([score for _, score in level_scores]) if level_scores else 0.0
# Create the row HTML - just show max score
            categories_html.append(f'''
            <tr>
                <td style="padding: 10px 16px;">{category_display}</td>
                <td style="padding: 10px 16px; text-align: center;">{format_score_with_style(f"{max_score:.4f}")}</td>
            </tr>
            ''')
        # Create refined HTML table for dark mode
        html_table = f'''
        <h3 style="text-align: center;">📊 Category-Specific Scores</h3>
        <table style="width: 100%; border-collapse: collapse;">
            <thead>
                <tr>
                    <th style="text-align: left; padding: 10px 16px;">Category</th>
                    <th style="text-align: center; padding: 10px 16px;">Score</th>
                </tr>
            </thead>
            <tbody>{"".join(categories_html)}</tbody>
        </table>
        '''
        # Create voting section
        voting_html = f'''
        <div style="text-align: center; padding: 12px;">
            <h4>📊 How accurate are these results?</h4>
            <p style="color: #9ca3af;">Your feedback helps improve the model</p>
        </div>
        '''
return format_binary_score(binary_score), html_table, text_id, voting_html
    except Exception as e:
        error_msg = f"Error analyzing text: {str(e)}"
        error_html = f'<div style="text-align: center; color: #fca5a5; padding: 16px;">❌ {error_msg}</div>'
        return error_html, error_html, "", ""
# Voting functions
def vote_thumbs_up(text_id):
"""Handle thumbs up vote"""
if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
save_vote_data(text_id, True)
        return '''
        <div style="text-align: center; padding: 16px;">
            <div style="font-size: 2em;">🎉</div>
            <strong>Awesome! Thank you!</strong>
            <p>Your positive feedback helps improve our AI safety models</p>
        </div>
        '''
    return '<div style="text-align: center; color: #9ca3af;">Voting not available</div>'
def vote_thumbs_down(text_id):
"""Handle thumbs down vote"""
if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
save_vote_data(text_id, False)
        return '''
        <div style="text-align: center; padding: 16px;">
            <div style="font-size: 2em;">🙏</div>
            <strong>Thanks for the feedback!</strong>
            <p>Critical feedback helps us identify and fix model weaknesses</p>
        </div>
        '''
    return '<div style="text-align: center; color: #9ca3af;">Voting not available</div>'
# Create Gradio interface with dark theme
with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
body_background_fill="*neutral_950",
background_fill_primary="*neutral_900",
background_fill_secondary="*neutral_800",
border_color_primary="*neutral_700",
color_accent_soft="*blue_500"
)) as demo:
    gr.HTML("""
    <h1 style="text-align: center;">LionGuard2</h1>
    <p style="text-align: center; color: #9ca3af;">Content moderation for the Singapore context</p>
    """)
with gr.Row():
with gr.Column(scale=1, min_width=400):
text_input = gr.Textbox(
label="Enter text to analyze:",
placeholder="Type your text here...",
lines=12,
max_lines=20,
container=True
)
            analyze_btn = gr.Button("🔍 Analyze Text", variant="primary")
with gr.Column(scale=1, min_width=400):
            gr.HTML("""
            <h3 style="text-align: center;">Overall Safety Score</h3>
            <p style="text-align: center; color: #9ca3af;">Higher percentages indicate higher likelihood of harmful content</p>
            """)
            binary_output = gr.HTML(
                value='<div style="text-align: center; color: #9ca3af; padding: 16px;">Enter text to analyze</div>'
            )
            category_table = gr.HTML(
                value='<div style="text-align: center; color: #9ca3af; padding: 16px;">Category scores will appear here after analysis</div>'
            )
# Enhanced Voting section - always visible and prominent
with gr.Row():
with gr.Column():
            gr.HTML("""
            <h3 style="text-align: center;">🎯 Help Us Improve! Rate the Analysis</h3>
            <p style="text-align: center; color: #9ca3af;">Your feedback trains better AI safety models</p>
            """)
            voting_instruction = gr.HTML(value="""
            <p style="text-align: center; color: #9ca3af;">⬆️ Analyze some text above to unlock voting ⬆️</p>
            """)
with gr.Row(visible=False) as voting_buttons_row:
with gr.Column(scale=1):
thumbs_up_btn = gr.Button(
"đ Results Look Accurate",
variant="primary",
size="lg",
elem_classes=["voting-btn-positive"]
)
with gr.Column(scale=1):
thumbs_down_btn = gr.Button(
"đ Results Look Wrong",
variant="secondary",
size="lg",
elem_classes=["voting-btn-negative"]
)
vote_feedback = gr.HTML(value="")
# Hidden text_id to track current analysis
current_text_id = gr.Textbox(value="", visible=False)
# Add information about the categories
with gr.Row():
        with gr.Accordion("ℹ️ About the Scoring System", open=False):
            gr.HTML("""
            <h4>How Scoring Works:</h4>
            <ul>
                <li>Percentages represent likelihood of harmful content - Higher % = More likely to be harmful</li>
                <li><strong>0-40%:</strong> Content appears safe</li>
                <li><strong>40-70%:</strong> Potentially concerning content that warrants review</li>
                <li><strong>70-100%:</strong> High likelihood of policy violation</li>
            </ul>
            <h4>Content Categories (Singapore Context):</h4>
            <ul>
                <li><strong>🤬 Hateful:</strong> Content targeting Singapore's protected traits (e.g., race, religion), including discriminatory remarks and explicit calls for harm/violence.</li>
                <li><strong>💢 Insults:</strong> Personal attacks on non-protected attributes (e.g., appearance). Note: sexuality-based attacks are classified as insults, not hateful, in Singapore.</li>
                <li><strong>🔞 Sexual:</strong> Sexual content or adult themes, ranging from mild content inappropriate for minors to explicit content inappropriate for general audiences.</li>
                <li><strong>⚔️ Physical Violence:</strong> Threats, descriptions, or glorification of physical harm against individuals or groups (not property damage).</li>
                <li><strong>☹️ Self Harm:</strong> Content about self-harm or suicide, including ideation, encouragement, or descriptions of ongoing actions.</li>
                <li><strong>🙅‍♂️ All Other Misconduct:</strong> Unethical or criminal conduct not covered above, from socially condemned behavior to clearly illegal activities under Singapore law.</li>
            </ul>
            """)
# Function to handle analysis and show voting buttons
def analyze_and_show_voting(text):
binary_score, category_table, text_id, voting_html = analyze_text(text)
# Show voting buttons and hide instructions if we have results
if text_id:
voting_buttons_update = gr.update(visible=True)
            voting_instruction_update = gr.update(value="""
            <p style="text-align: center;">✅ Analysis Complete! Please rate the accuracy below 👇</p>
            """)
        else:
            voting_buttons_update = gr.update(visible=False)
            voting_instruction_update = gr.update(value="""
            <p style="text-align: center; color: #9ca3af;">⬆️ Analyze some text above to unlock voting ⬆️</p>
            """)
return binary_score, category_table, text_id, voting_buttons_update, voting_instruction_update, ""
# Connect the analyze button to the function
analyze_btn.click(
fn=analyze_and_show_voting,
inputs=[text_input],
outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_instruction, vote_feedback]
)
# Allow Enter key to trigger analysis
text_input.submit(
fn=analyze_and_show_voting,
inputs=[text_input],
outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_instruction, vote_feedback]
)
# Connect voting buttons
thumbs_up_btn.click(
fn=vote_thumbs_up,
inputs=[current_text_id],
outputs=[vote_feedback]
)
thumbs_down_btn.click(
fn=vote_thumbs_down,
inputs=[current_text_id],
outputs=[vote_feedback]
)
if __name__ == "__main__":
demo.launch(share=True, server_name="0.0.0.0", server_port=7860)