"""Gradio demo for the ``logasanjeev/emotion-analyzer-bert`` emotion classifier.

At import time the model repo's ``inference.py`` is downloaded from the Hugging
Face Hub, copied into the current working directory and imported. That module
supplies ``predict_emotions`` plus the ``TOKENIZER``, ``MODEL``, ``DEVICE``,
``EMOTION_LABELS`` and ``THRESHOLDS`` objects used below.
"""

import os
import shutil
import sys
from importlib import import_module, invalidate_caches

import gradio as gr
import pandas as pd
import plotly.express as px
import torch
from huggingface_hub import hf_hub_download

# Load inference.py and model
repo_id = "logasanjeev/emotion-analyzer-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")

current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")

# The copy lives in the cwd, which is not on sys.path when the script is run
# from another directory (sys.path[0] is then the script's own directory).
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# Required so the import system notices a source file created after start-up.
invalidate_caches()
inference_module = import_module("inference")
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")

# Warm-up call; the result is discarded. Presumably this is so any lazy
# initialisation inside inference.py happens at start-up, not on the first request.
_, _ = predict_emotions("dummy text")

emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
# Per-label default threshold, built once instead of calling list.index() per prediction.
_threshold_by_emotion = dict(zip(emotion_labels, default_thresholds))


def _parse_predictions(predictions_str):
    """Parse ``predict_emotions`` output ("emotion: confidence" lines) into (emotion, float) pairs."""
    if predictions_str == "No emotions predicted.":
        return []
    predictions = []
    for line in predictions_str.splitlines():
        if not line.strip():
            continue
        # rsplit keeps the confidence as the last field even if a label contained ": ".
        emotion, confidence = line.rsplit(": ", 1)
        predictions.append((emotion.strip(), float(confidence)))
    return predictions


def _score_all_emotions(processed_text):
    """Return (emotion, probability rounded to 4 dp) for every label, highest first."""
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(inference_module.DEVICE)
    attention_mask = encodings["attention_mask"].to(inference_module.DEVICE)
    with torch.no_grad():
        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
    # Sigmoid applied element-wise: one independent probability per label.
    probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]
    scores = [(label, round(float(prob), 4)) for label, prob in zip(emotion_labels, probs)]
    # Stable sort: ties keep label order.
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores


def _filter_by_threshold(predictions, confidence_threshold):
    """Keep predictions whose confidence >= max(label default threshold, user threshold)."""
    return [
        (emotion, confidence)
        for emotion, confidence in predictions
        if confidence >= max(_threshold_by_emotion[emotion], confidence_threshold)
    ]


def _build_chart(filtered_predictions, top_5_emotions):
    """Build a grouped bar chart comparing above-threshold and Top-5 confidences."""
    thresholded_dict = dict(filtered_predictions)
    top_5_dict = dict(top_5_emotions)
    # Deterministic bar order: Top-5 by rank, then remaining above-threshold emotions.
    # (Iterating a set of strings would order bars by hash, which varies per process.)
    ordered_emotions = list(
        dict.fromkeys(
            [emotion for emotion, _ in top_5_emotions]
            + [emotion for emotion, _ in filtered_predictions]
        )
    )

    rows = []
    for emotion in ordered_emotions:
        if emotion in thresholded_dict:
            rows.append(
                {"Emotion": emotion, "Confidence": thresholded_dict[emotion], "Category": "Above Threshold"}
            )
        if emotion in top_5_dict:
            rows.append({"Emotion": emotion, "Confidence": top_5_dict[emotion], "Category": "Top 5"})
    df = pd.DataFrame(rows, columns=["Emotion", "Confidence", "Category"])

    fig = px.bar(
        df,
        x="Emotion",
        y="Confidence",
        color="Category",
        barmode="group",
        title="Emotion Confidence Comparison",
        height=400,
        color_discrete_map={"Above Threshold": "#6366f1", "Top 5": "#10b981"},
    )
    fig.update_traces(texttemplate="%{y:.2f}", textposition="auto")
    fig.update_layout(
        margin=dict(t=50, b=50),
        xaxis_title="",
        yaxis_title="Confidence",
        legend_title="",
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.5),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="#e5e7eb"),
    )
    return fig


# Prediction function with grouped bar chart
def predict_emotions_with_details(text, confidence_threshold=0.0):
    """Run the emotion model on *text* and build the four UI outputs.

    Args:
        text: Raw user input.
        confidence_threshold: Extra cut-off from the UI slider. A prediction is
            kept only if its confidence is >= max(its label's default
            threshold, this value). ``None`` is treated as 0.0.

    Returns:
        A 4-tuple ``(processed_text, thresholded_output, top_5_output, fig)``:
        - ``processed_text``: the text returned by ``predict_emotions``.
        - ``thresholded_output`` and ``top_5_output``: newline-separated
          "emotion: confidence" strings for the above-threshold and Top-5
          emotions.
        - ``fig``: a Plotly grouped bar chart.
        For empty input it returns ``("Please enter some text.", "", "", None)``.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None
    if confidence_threshold is None:
        confidence_threshold = 0.0

    predictions_str, processed_text = predict_emotions(text)
    predictions = _parse_predictions(predictions_str)

    # Scores for every label (second forward pass) feed the Top-5 view.
    top_5_emotions = _score_all_emotions(processed_text)[:5]
    top_5_output = "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions)

    filtered_predictions = _filter_by_threshold(predictions, confidence_threshold)
    if filtered_predictions:
        thresholded_output = "\n".join(
            f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
        )
    else:
        thresholded_output = "No emotions predicted above thresholds."

    fig = None
    if filtered_predictions or top_5_emotions:
        fig = _build_chart(filtered_predictions, top_5_emotions)

    return processed_text, thresholded_output, top_5_output, fig


# Enhanced CSS with modern design
# NOTE(review): the elem_classes used below ("input-textbox", "output-plot", ...)
# do not match the .gr-* selectors here — confirm which classes your Gradio version emits.
custom_css = """
body {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    background: linear-gradient(135deg, #111827 0%, #1f2937 100%);
    color: #e5e7eb;
    margin: 0;
    padding: 24px;
    min-height: 100vh;
    display: flex;
    flex-direction: column;
    align-items: center;
}
.gr-panel {
    border-radius: 16px;
    box-shadow: 0 10px 30px rgba(0,0,0,0.3);
    background: rgba(31, 41, 55, 0.9);
    backdrop-filter: blur(12px);
    padding: 32px;
    margin: 24px auto;
    max-width: 960px;
    width: 100%;
    border: 1px solid rgba(255, 255, 255, 0.1);
    transition: transform 0.3s ease, box-shadow 0.3s ease;
}
.gr-panel:hover {
    transform: translateY(-4px);
    box-shadow: 0 12px 40px rgba(0,0,0,0.35);
}
.gr-button {
    border-radius: 8px;
    padding: 12px 32px;
    font-weight: 600;
    font-size: 16px;
    background: linear-gradient(90deg, #6366f1 0%, #8b5cf6 100%);
    color: #ffffff;
    border: none;
    transition: all 0.3s ease;
    cursor: pointer;
    margin-top: 16px;
}
.gr-button:hover {
    background: linear-gradient(90deg, #8b5cf6 0%, #6366f1 100%);
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(99, 102, 241, 0.4);
}
.gr-button:focus {
    outline: none;
    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.3);
}
.gr-textbox, .gr-slider {
    margin-bottom: 24px;
}
.gr-textbox label, .gr-slider label {
    font-size: 16px;
    font-weight: 600;
    color: #e5e7eb;
    margin-bottom: 8px;
    display: block;
}
.gr-textbox textarea, .gr-textbox input {
    border: 1px solid rgba(255, 255, 255, 0.15);
    border-radius: 8px;
    padding: 12px;
    font-size: 16px;
    background: rgba(55, 65, 81, 0.5);
    color: #e5e7eb;
    transition: border-color 0.3s ease, box-shadow 0.3s ease;
}
.gr-textbox textarea:focus, .gr-textbox input:focus {
    border-color: #6366f1;
    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
    outline: none;
}
#title {
    font-size: 2.5rem;
    font-weight: 800;
    color: #ffffff;
    text-align: center;
    margin: 32px 0 16px 0;
    background: linear-gradient(90deg, #6366f1, #8b5cf6);
    -webkit-background-clip: text;
    background-clip: text;
    -webkit-text-fill-color: transparent;
}
#description {
    font-size: 1.125rem;
    color: #d1d5db;
    text-align: center;
    max-width: 720px;
    margin: 0 auto 48px auto;
    line-height: 1.75;
}
#examples-title {
    font-size: 1.25rem;
    font-weight: 600;
    color: #e5e7eb;
    margin: 32px 0 16px 0;
    text-align: center;
}
footer {
    text-align: center;
    margin: 48px 0 24px 0;
    padding: 16px;
    font-size: 14px;
    color: #d1d5db;
}
footer a {
    color: #6366f1;
    text-decoration: none;
    font-weight: 500;
    transition: color 0.3s ease;
}
footer a:hover {
    color: #8b5cf6;
}
.gr-plot {
    margin-top: 24px;
    background: rgba(31, 41, 55, 0.5);
    border-radius: 12px;
    padding: 16px;
    border: 1px solid rgba(255, 255, 255, 0.1);
}
.gr-examples .example {
    background: rgba(55, 65, 81, 0.7);
    border-radius: 10px;
    padding: 16px;
    margin: 12px 0;
    transition: all 0.3s ease;
    cursor: pointer;
    border: 1px solid rgba(255, 255, 255, 0.1);
}
.gr-examples .example:hover {
    background: rgba(99, 102, 241, 0.15);
    transform: translateY(-2px);
    border-color: #6366f1;
}
@keyframes fadeIn {
    from { opacity: 0; transform: translateY(16px); }
    to { opacity: 1; transform: translateY(0); }
}
.gr-panel, #title, #description, footer, .gr-examples .example {
    animation: fadeIn 0.6s ease-out;
}
/* Responsive design */
@media (max-width: 768px) {
    .gr-panel { padding: 24px; margin: 16px; }
    #title { font-size: 2rem; }
    #description { font-size: 1rem; }
    .gr-button { padding: 10px 24px; font-size: 14px; }
}
"""

# Gradio Blocks UI (Modernized)
with gr.Blocks(css=custom_css) as demo:
    # Header; styling comes from the #title / #description CSS rules via elem_id.
    gr.Markdown("Emotion Analyzer BERT", elem_id="title")
    gr.Markdown(
        "Uncover the emotions in your text with our fine-tuned BERT model, "
        "trained on the GoEmotions dataset. Enter your text, fine-tune the "
        "confidence threshold, and visualize the results in a sleek, "
        "interactive chart.",
        elem_id="description",
    )

    # Input Section
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Textbox(
                    label="Enter Your Text",
                    placeholder="Try: 'I'm over the moon today!' or 'This is so frustrating... 😣'",
                    lines=4,
                    show_label=True,
                    elem_classes=["input-textbox"],
                )
            with gr.Column(scale=1):
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.9,
                    value=0.0,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Filter emotions below this confidence level",
                    elem_classes=["input-slider"],
                )
        submit_btn = gr.Button("Analyze Emotions", variant="primary")

    # Output Section
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                processed_text_output = gr.Textbox(
                    label="Processed Text",
                    lines=2,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                thresholded_output = gr.Textbox(
                    label="Detected Emotions (Above Threshold)",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                top_5_output = gr.Textbox(
                    label="Top 5 Emotions",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
            with gr.Column(scale=2):
                output_plot = gr.Plot(
                    label="Emotion Confidence Visualization",
                    elem_classes=["output-plot"],
                )

    # Example carousel
    with gr.Group():
        gr.Markdown("Try These Examples", elem_id="examples-title")
        # One column per input component (only text_input); the tags are kept as comments.
        gr.Examples(
            examples=[
                ["I’m thrilled to win this award! 😄"],  # Joy
                ["This is so frustrating, nothing works. 😣"],  # Annoyance
                ["I feel so sorry for what happened. 😢"],  # Sadness
                ["What a beautiful day to be alive! 🌞"],  # Admiration
                ["Feeling nervous about the exam tomorrow 😓 u/student r/study"],  # Nervousness
            ],
            inputs=[text_input],
            label="",
        )

    # Footer (styled by the `footer` / `footer a` CSS rules)
    gr.HTML(
        f"""
        <footer>
            Model: <a href="https://huggingface.co/{repo_id}" target="_blank">{repo_id}</a>
        </footer>
        """
    )

    # Bind predictions
    submit_btn.click(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot],
    )

# Launch
if __name__ == "__main__":
    demo.launch()