""" simple_demo.py """ import os import gradio as gr from safetensors.torch import load_file from huggingface_hub import hf_hub_download import gspread from google.oauth2 import service_account import json from datetime import datetime import uuid # Local imports from lionguard2 import LionGuard2, CATEGORIES from utils import get_embeddings # Google Sheets configuration GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL") GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT") RESULTS_SHEET_NAME = "results" VOTES_SHEET_NAME = "votes" # Helper to save results data def save_results_data(row): try: # Create credentials object credentials = service_account.Credentials.from_service_account_info( json.loads(GOOGLE_CREDENTIALS), scopes=[ "https://www.googleapis.com/auth/spreadsheets", "https://www.googleapis.com/auth/drive", ], ) # Create authorized client gc = gspread.authorize(credentials) sheet = gc.open_by_url(GOOGLE_SHEET_URL) ws = sheet.worksheet(RESULTS_SHEET_NAME) ws.append_row(list(row.values())) print(f"Saved results data for text_id: {row['text_id']}") except Exception as e: print(f"Error saving results data: {e}") # Helper to save vote data def save_vote_data(text_id, agree): try: # Create credentials object credentials = service_account.Credentials.from_service_account_info( json.loads(GOOGLE_CREDENTIALS), scopes=[ "https://www.googleapis.com/auth/spreadsheets", "https://www.googleapis.com/auth/drive", ], ) # Create authorized client gc = gspread.authorize(credentials) sheet = gc.open_by_url(GOOGLE_SHEET_URL) ws = sheet.worksheet(VOTES_SHEET_NAME) vote_row = { "datetime": datetime.now().isoformat(), "text_id": text_id, "agree": agree } ws.append_row(list(vote_row.values())) print(f"Saved vote data for text_id: {text_id}, agree: {agree}") except Exception as e: print(f"Error saving vote data: {e}") def download_model(repo_id, filename="LionGuard2.safetensors", token=None): """ Download the LionGuard2 model from a Hugging Face private repository. Args: repo_id: The Hugging Face repository ID (e.g., "username/repo-name") filename: The filename to download (default: "LionGuard2.safetensors") token: Hugging Face access token for private repositories Returns: Path to the downloaded file """ if token is None: token = os.environ.get("HF_API_KEY") # Download the model file model_path = hf_hub_download( repo_id=repo_id, filename=filename, token=token, cache_dir="./cache" ) return model_path def load_model(repo_id=None, use_local=True): """ Load the LionGuard2 model from either local file or Hugging Face repository. Args: repo_id: The Hugging Face repository ID (optional) use_local: Whether to use local file first (default: True) """ model = LionGuard2() model.eval() model_path = "LionGuard2.safetensors" # Try to download from HF repo if specified and local file doesn't exist or use_local is False if repo_id and (not use_local or not os.path.exists(model_path)): try: print(f"Downloading LionGuard2.safetensors from {repo_id}...") model_path = download_model(repo_id) print(f"Model downloaded to: {model_path}") except Exception as e: print(f"Failed to download from HF repo: {e}") if not os.path.exists("LionGuard2.safetensors"): raise Exception("No local model file found and failed to download from HF repo") print("Falling back to local file...") state_dict = load_file(model_path) model.load_state_dict(state_dict) return model # Load model once at startup HF_REPO_ID = "gabrielchua/refactored-guacamole" # Update this with the actual repo ID model = load_model(repo_id=HF_REPO_ID) def format_score_with_style(score_str, compact=False): """ Format score with color and emoji based on value. Args: score_str: Score as string (could be "-" for missing values) compact: Whether to use compact styling Returns: HTML formatted string with styling """ if score_str == "-": return '-' try: score = float(score_str) percentage = int(score * 100) padding = "6px 12px" if compact else "8px 16px" if score < 0.4: # Safe - refined green return f'👌 {percentage}%' elif 0.4 <= score < 0.7: # Warning - refined amber return f'âš ī¸ {percentage}%' else: # High risk - refined red return f'🚨 {percentage}%' except: return score_str def format_binary_score(score): """Format the binary score with appropriate styling for dark mode.""" percentage = int(score * 100) if score < 0.4: return f'
✅ Pass ({percentage}/100)
' elif 0.4 <= score < 0.7: return f'
âš ī¸ Warning ({percentage}/100)
' else: return f'
🚨 Fail ({percentage}/100)
' def analyze_text(text): """ Analyze text for content moderation violations. Args: text: Input text to analyze Returns: binary_score: Overall safety score with styling category_table: HTML table with category-specific scores and styling text_id: Unique identifier for this analysis voting_section: HTML for voting buttons """ if not text.strip(): empty_html = '
Enter text to analyze
' return '
Enter text to analyze
', empty_html, "", "" try: # Generate unique text ID text_id = str(uuid.uuid4()) # Get embeddings for the text embeddings = get_embeddings([text]) # Run inference results = model.predict(embeddings) # Extract binary score (overall safety) binary_score = results.get('binary', [0.0])[0] # Extract specific scores for Google Sheets hateful_scores = CATEGORIES['hateful'] hateful_l1_score = results.get(hateful_scores[0], [0.0])[0] if len(hateful_scores) > 0 else 0.0 hateful_l2_score = results.get(hateful_scores[1], [0.0])[0] if len(hateful_scores) > 1 else 0.0 insults_scores = CATEGORIES['insults'] insults_score = results.get(insults_scores[0], [0.0])[0] if len(insults_scores) > 0 else 0.0 sexual_scores = CATEGORIES['sexual'] sexual_l1_score = results.get(sexual_scores[0], [0.0])[0] if len(sexual_scores) > 0 else 0.0 sexual_l2_score = results.get(sexual_scores[1], [0.0])[0] if len(sexual_scores) > 1 else 0.0 physical_violence_scores = CATEGORIES['physical_violence'] physical_violence_score = results.get(physical_violence_scores[0], [0.0])[0] if len(physical_violence_scores) > 0 else 0.0 self_harm_scores = CATEGORIES['self_harm'] self_harm_l1_score = results.get(self_harm_scores[0], [0.0])[0] if len(self_harm_scores) > 0 else 0.0 self_harm_l2_score = results.get(self_harm_scores[1], [0.0])[0] if len(self_harm_scores) > 1 else 0.0 aom_scores = CATEGORIES['all_other_misconduct'] aom_l1_score = results.get(aom_scores[0], [0.0])[0] if len(aom_scores) > 0 else 0.0 aom_l2_score = results.get(aom_scores[1], [0.0])[0] if len(aom_scores) > 1 else 0.0 # Save results to Google Sheets if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS: results_row = { "datetime": datetime.now().isoformat(), "text_id": text_id, "text": text, "binary_score": binary_score, "hateful_l1_score": hateful_l1_score, "hateful_l2_score": hateful_l2_score, "insults_score": insults_score, "sexual_l1_score": sexual_l1_score, "sexual_l2_score": sexual_l2_score, "physical_violence_score": physical_violence_score, "self_harm_l1_score": self_harm_l1_score, "self_harm_l2_score": self_harm_l2_score, "aom_l1_score": aom_l1_score, "aom_l2_score": aom_l2_score } save_results_data(results_row) # Prepare category data with max scores and dropdowns categories_html = [] # Define the main categories (excluding binary) main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct'] for category in main_categories: subcategories = CATEGORIES[category] category_name = category.replace('_', ' ').title() # Add emoji to category name based on type category_emojis = { 'Hateful': 'đŸ¤Ŧ', 'Insults': 'đŸ’ĸ', 'Sexual': '🔞', 'Physical Violence': 'âš”ī¸', 'Self Harm': 'â˜šī¸', 'All Other Misconduct': 'đŸ™…â€â™€ī¸' } category_display = f"{category_emojis.get(category_name, '📝')} {category_name}" # Get scores for all levels level_scores = [] for i, subcategory_key in enumerate(subcategories): score = results.get(subcategory_key, [0.0])[0] level_scores.append((f"Level {i+1}", score)) # Find max score max_score = max([score for _, score in level_scores]) if level_scores else 0.0 # Create the row HTML - just show max score categories_html.append(f''' {category_display} {format_score_with_style(f"{max_score:.4f}")} ''') # Create refined HTML table for dark mode html_table = f'''

📊 Category-Specific Scores

{"".join(categories_html)}
Category Score
''' # Create voting section voting_html = f'''

📊 How accurate are these results?

Your feedback helps improve the model

''' return format_binary_score(binary_score), html_table, text_id, voting_html except Exception as e: error_msg = f"Error analyzing text: {str(e)}" error_html = f'
❌ {error_msg}
' return f'
❌ {error_msg}
', error_html, "", "" # Voting functions def vote_thumbs_up(text_id): """Handle thumbs up vote""" if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS: save_vote_data(text_id, True) return '''
🎉

Awesome! Thank you!

Your positive feedback helps improve our AI safety models

''' return '
Voting not available
' def vote_thumbs_down(text_id): """Handle thumbs down vote""" if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS: save_vote_data(text_id, False) return '''
📝

Thanks for the feedback!

Critical feedback helps us identify and fix model weaknesses

''' return '
Voting not available
' # Create Gradio interface with dark theme with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set( body_background_fill="*neutral_950", background_fill_primary="*neutral_900", background_fill_secondary="*neutral_800", border_color_primary="*neutral_700", color_accent_soft="*blue_500" )) as demo: gr.HTML("""
""") with gr.Row(): with gr.Column(scale=1, min_width=400): text_input = gr.Textbox( label="Enter text to analyze:", placeholder="Type your text here...", lines=12, max_lines=20, container=True ) analyze_btn = gr.Button("🔍 Analyze Text", variant="primary") with gr.Column(scale=1, min_width=400): gr.HTML("""

Overall Safety Score

Higher percentages indicate higher likelihood of harmful content

""") binary_output = gr.HTML( value='
Enter text to analyze
' ) category_table = gr.HTML( value='
Category scores will appear here after analysis
' ) # Enhanced Voting section - always visible and prominent with gr.Row(): with gr.Column(): gr.HTML("""

đŸŽ¯ Help Us Improve! Rate the Analysis

Your feedback trains better AI safety models

""") voting_instruction = gr.HTML(value="""

âŦ†ī¸ Analyze some text above to unlock voting âŦ†ī¸

""") with gr.Row(visible=False) as voting_buttons_row: with gr.Column(scale=1): thumbs_up_btn = gr.Button( "👍 Results Look Accurate", variant="primary", size="lg", elem_classes=["voting-btn-positive"] ) with gr.Column(scale=1): thumbs_down_btn = gr.Button( "👎 Results Look Wrong", variant="secondary", size="lg", elem_classes=["voting-btn-negative"] ) vote_feedback = gr.HTML(value="") # Hidden text_id to track current analysis current_text_id = gr.Textbox(value="", visible=False) # Add information about the categories with gr.Row(): with gr.Accordion("â„šī¸ About the Scoring System", open=False): gr.HTML("""

How Scoring Works:

Content Categories (Singapore Context):

""") # Function to handle analysis and show voting buttons def analyze_and_show_voting(text): binary_score, category_table, text_id, voting_html = analyze_text(text) # Show voting buttons and hide instructions if we have results if text_id: voting_buttons_update = gr.update(visible=True) voting_instruction_update = gr.update(value="""

✅ Analysis Complete! Please rate the accuracy below 👇

""") else: voting_buttons_update = gr.update(visible=False) voting_instruction_update = gr.update(value="""

âŦ†ī¸ Analyze some text above to unlock voting âŦ†ī¸

""") return binary_score, category_table, text_id, voting_buttons_update, voting_instruction_update, "" # Connect the analyze button to the function analyze_btn.click( fn=analyze_and_show_voting, inputs=[text_input], outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_instruction, vote_feedback] ) # Allow Enter key to trigger analysis text_input.submit( fn=analyze_and_show_voting, inputs=[text_input], outputs=[binary_output, category_table, current_text_id, voting_buttons_row, voting_instruction, vote_feedback] ) # Connect voting buttons thumbs_up_btn.click( fn=vote_thumbs_up, inputs=[current_text_id], outputs=[vote_feedback] ) thumbs_down_btn.click( fn=vote_thumbs_down, inputs=[current_text_id], outputs=[vote_feedback] ) if __name__ == "__main__": demo.launch(share=True, server_name="0.0.0.0", server_port=7860)