"""Gradio demo for the ``logasanjeev/emotion-analyzer-bert`` emotion classifier.

The model helpers come from an ``inference.py`` file published in the model's
Hugging Face repo. That file is downloaded at start-up and imported as the
``inference`` module. This script reads these attributes from it:
``predict_emotions``, ``EMOTION_LABELS``, ``THRESHOLDS``, ``TOKENIZER``,
``MODEL`` and ``DEVICE``.
"""

import os
import shutil
import sys
from importlib import import_module, invalidate_caches

import gradio as gr
import pandas as pd
import plotly.express as px
import torch
from huggingface_hub import hf_hub_download

# Load inference.py and model
repo_id = "logasanjeev/emotion-analyzer-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")

current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")

# The copy lives in the *working* directory. That directory is not necessarily
# on sys.path: ``python app.py`` adds the script's directory, not the cwd.
# The import system also caches directory listings, so a source file created
# after start-up may be invisible until those caches are invalidated.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
invalidate_caches()

inference_module = import_module("inference")
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")

# Presumably a warm-up call so the first real request is not slowed by
# one-off initialisation; the result is discarded.
_, _ = predict_emotions("dummy text")

emotion_labels = inference_module.EMOTION_LABELS
# Indexed by label position below, i.e. assumed parallel to EMOTION_LABELS.
default_thresholds = inference_module.THRESHOLDS

# Sentinel string that predict_emotions returns when nothing is detected.
_NO_PREDICTIONS = "No emotions predicted."


def _parse_predictions(predictions_str):
    """Parse ``predict_emotions`` output into ``(emotion, confidence)`` pairs.

    The expected format is one ``"<emotion>: <confidence>"`` entry per line.
    The sentinel ``"No emotions predicted."`` yields an empty list.

    Raises:
        ValueError: if a non-blank line has no parseable confidence.
    """
    if predictions_str == _NO_PREDICTIONS:
        return []
    predictions = []
    for line in predictions_str.splitlines():
        if not line.strip():
            continue  # tolerate blank/trailing lines
        # rpartition: split on the *last* ": " so the numeric part is isolated.
        emotion, _, confidence = line.rpartition(": ")
        predictions.append((emotion, float(confidence)))
    return predictions


def _score_all_emotions(processed_text):
    """Score every emotion label for *processed_text*, highest score first.

    Runs the model once more on the already-processed text. Returns
    ``(label, probability)`` tuples, where probability is ``sigmoid(logit)``
    rounded to 4 decimals. Ties keep label order because the sort is stable.
    """
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(inference_module.DEVICE)
    attention_mask = encodings["attention_mask"].to(inference_module.DEVICE)
    with torch.no_grad():
        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
    # Sigmoid per label: these are independent probabilities, not raw logits.
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]
    scores = [
        (emotion_labels[i], round(float(prob), 4))
        for i, prob in enumerate(probabilities)
    ]
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores


def _build_comparison_chart(filtered_predictions, top_5_emotions):
    """Build a grouped bar chart comparing thresholded and top-5 emotions.

    Returns a plotly Figure, or ``None`` when both inputs are empty.
    """
    if not (filtered_predictions or top_5_emotions):
        return None

    # Deterministic bar order: top-5 (descending) first, then any extra
    # thresholded emotions. A set would give a run-dependent order because
    # string hashing is randomised per interpreter.
    emotions = list(dict.fromkeys(
        [emotion for emotion, _ in top_5_emotions]
        + [emotion for emotion, _ in filtered_predictions]
    ))
    thresholded_dict = dict(filtered_predictions)
    top_5_dict = dict(top_5_emotions)

    data = {"Emotion": [], "Confidence": [], "Category": []}
    for emotion in emotions:
        if emotion in thresholded_dict:
            data["Emotion"].append(emotion)
            data["Confidence"].append(thresholded_dict[emotion])
            data["Category"].append("Above Threshold")
        if emotion in top_5_dict:
            data["Emotion"].append(emotion)
            data["Confidence"].append(top_5_dict[emotion])
            data["Category"].append("Top 5")

    df = pd.DataFrame(data)
    fig = px.bar(
        df,
        x="Emotion",
        y="Confidence",
        color="Category",
        barmode="group",
        title="Emotion Confidence Comparison",
        height=400,
        color_discrete_map={"Above Threshold": "#ff6b6b", "Top 5": "#4ecdc4"},
    )
    fig.update_traces(texttemplate='%{y:.2f}', textposition='auto')
    fig.update_layout(
        margin=dict(t=50, b=50),
        xaxis_title="",
        yaxis_title="Confidence",
        legend_title="",
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.5),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="#e0e0e0"),
    )
    return fig


def predict_emotions_with_details(text, confidence_threshold=0.0):
    """Analyse *text* and return the values shown by the four output widgets.

    Args:
        text: user input; ``None`` or whitespace-only input is rejected.
        confidence_threshold: extra minimum confidence. For each emotion, the
            effective cut-off is ``max(per-label default, confidence_threshold)``.
            ``None`` is treated as 0.0.

    Returns:
        A tuple ``(processed_text, thresholded_text, top_5_text, figure)``.
        For empty input it is ``("Please enter some text.", "", "", None)``.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None
    if confidence_threshold is None:
        confidence_threshold = 0.0

    predictions_str, processed_text = predict_emotions(text)
    predictions = _parse_predictions(predictions_str)

    top_5_emotions = _score_all_emotions(processed_text)[:5]
    top_5_text = "\n".join(
        f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions
    )

    filtered_predictions = [
        (emotion, confidence)
        for emotion, confidence in predictions
        if confidence >= max(
            default_thresholds[emotion_labels.index(emotion)], confidence_threshold
        )
    ]
    if filtered_predictions:
        thresholded_text = "\n".join(
            f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
        )
    else:
        thresholded_text = "No emotions predicted above thresholds."

    fig = _build_comparison_chart(filtered_predictions, top_5_emotions)
    return processed_text, thresholded_text, top_5_text, fig


# Enhanced CSS with vibrant colors, animations, and better UX.
# NOTE(review): the .gr-textbox/.gr-slider/.gr-plot/.gr-examples selectors do
# not match the elem_classes assigned below (input-textbox, output-plot, ...).
# Confirm which classes the installed Gradio version actually renders.
custom_css = """
body {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
    background: linear-gradient(145deg, #1a1a3d 0%, #2e2e5c 100%);
    color: #e6e6fa;
    margin: 0;
    padding: 20px;
    min-height: 100vh;
}
.gr-panel {
    border-radius: 12px;
    box-shadow: 0 6px 20px rgba(0,0,0,0.25);
    background: rgba(255, 255, 255, 0.08);
    backdrop-filter: blur(8px);
    padding: 25px;
    margin: 20px auto;
    max-width: 900px;
    border: 1px solid rgba(255, 255, 255, 0.15);
    transition: transform 0.3s ease;
}
.gr-panel:hover {
    transform: translateY(-5px);
}
.gr-button {
    border-radius: 8px;
    padding: 12px 30px;
    font-weight: 600;
    background: linear-gradient(90deg, #ff6b6b 0%, #ff8e53 100%);
    color: white;
    border: none;
    transition: all 0.3s ease;
    cursor: pointer;
    margin-top: 15px;
}
.gr-button:hover {
    background: linear-gradient(90deg, #ff8e53 0%, #ff6b6b 100%);
    transform: scale(1.05);
    box-shadow: 0 4px 15px rgba(255, 107, 107, 0.4);
}
.gr-textbox, .gr-slider {
    margin-bottom: 20px;
}
.gr-textbox label, .gr-slider label {
    font-size: 1.1em;
    font-weight: 600;
    color: #e6e6fa;
    margin-bottom: 8px;
    display: block;
}
.gr-textbox textarea, .gr-textbox input {
    border: 1px solid rgba(255, 255, 255, 0.2);
    border-radius: 6px;
    padding: 10px;
    font-size: 1em;
    background: rgba(255, 255, 255, 0.1);
    color: #e6e6fa;
    transition: border-color 0.3s ease;
}
.gr-textbox textarea:focus, .gr-textbox input:focus {
    border-color: #ff6b6b;
    outline: none;
}
#title {
    font-size: 2.2em;
    font-weight: 700;
    color: #ffffff;
    text-align: center;
    margin: 40px 0 15px 0;
    text-shadow: 0 2px 4px rgba(0,0,0,0.3);
}
#description {
    font-size: 1.1em;
    color: #d3d3fa;
    text-align: center;
    max-width: 700px;
    margin: 0 auto 40px auto;
    line-height: 1.5;
}
#examples-title {
    font-size: 1.3em;
    font-weight: 600;
    color: #e6e6fa;
    margin: 30px 0 15px 0;
    text-align: center;
}
footer {
    text-align: center;
    margin: 40px 0;
    padding: 20px;
    font-size: 1em;
    color: #d3d3fa;
}
footer a {
    color: #ff6b6b;
    text-decoration: none;
    font-weight: 500;
    transition: color 0.3s ease;
}
footer a:hover {
    color: #ff8e53;
}
.gr-plot {
    margin-top: 20px;
    background: rgba(255, 255, 255, 0.1);
    border-radius: 10px;
    padding: 15px;
    border: 1px solid rgba(255, 255, 255, 0.15);
}
.gr-examples .example {
    background: rgba(255, 255, 255, 0.12);
    border-radius: 8px;
    padding: 12px;
    margin: 8px 0;
    transition: all 0.3s ease;
    cursor: pointer;
}
.gr-examples .example:hover {
    background: rgba(255, 107, 107, 0.2);
    transform: translateY(-3px);
}
@keyframes fadeIn {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
.gr-panel, #title, #description, footer {
    animation: fadeIn 0.5s ease-out;
}
"""

# Example inputs paired with a descriptive tag. Only the text is passed to
# gr.Examples: each example row must supply one value per input component,
# and there is a single input (text_input). The tag is kept for readers only.
_EXAMPLE_TEXTS = [
    ("I’m thrilled to win this award! 😄", "Joy Example"),
    ("This is so frustrating, nothing works. 😣", "Annoyance Example"),
    ("I feel so sorry for what happened. 😢", "Sadness Example"),
    ("What a beautiful day to be alive! 🌞", "Admiration Example"),
    ("Feeling nervous about the exam tomorrow 😓 u/student r/study", "Nervousness Example"),
]

# Gradio Blocks UI (Enhanced for vibrancy and UX)
with gr.Blocks(css=custom_css) as demo:
    # Header
    gr.Markdown("Emotion Analyzer BERT", elem_id="title")
    gr.Markdown(
        "Discover the emotions in your text with our fine-tuned BERT model! "
        "Type your thoughts below, adjust the confidence threshold, and explore "
        "the detected emotions with a vibrant visualization.",
        elem_id="description",
    )

    # Input Section
    with gr.Group():
        text_input = gr.Textbox(
            label="Share Your Thoughts",
            placeholder="Try something like 'I’m super excited today!' or 'This is so annoying...'",
            lines=3,
            show_label=True,
            elem_classes=["input-textbox"],
        )
        confidence_slider = gr.Slider(
            minimum=0.0,
            maximum=0.9,
            value=0.0,
            step=0.05,
            label="Confidence Threshold",
            info="Filter emotions below this confidence level (default thresholds apply)",
            elem_classes=["input-slider"],
        )
        submit_btn = gr.Button("Analyze Emotions", variant="primary")

    # Output Section
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                processed_text_output = gr.Textbox(
                    label="Processed Text",
                    lines=1,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                thresholded_output = gr.Textbox(
                    label="Detected Emotions (Above Threshold)",
                    lines=4,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                top_5_output = gr.Textbox(
                    label="Top 5 Emotions",
                    lines=4,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
            with gr.Column(scale=1):
                output_plot = gr.Plot(
                    label="Emotion Confidence Visualization",
                    elem_classes=["output-plot"],
                )

    # Example carousel
    with gr.Group():
        gr.Markdown("Explore Example Texts", elem_id="examples-title")
        examples = gr.Examples(
            examples=[[example_text] for example_text, _tag in _EXAMPLE_TEXTS],
            inputs=[text_input],
            label="",
        )

    # Footer
    # NOTE(review): the footer markup appears to be empty in this copy of the
    # file; restore the intended HTML (the CSS styles `footer` and `footer a`).
    gr.HTML(""" """)

    # Bind predictions
    submit_btn.click(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot],
    )

# Launch
if __name__ == "__main__":
    demo.launch()