# app.py for the logasanjeev/emotion-analyzer-bert Gradio demo (author: logasanjeev).
# Last commit: 78e6e06 "Update app.py".
import importlib.util
import os
import shutil
import sys
from importlib import import_module

import gradio as gr
import pandas as pd
import plotly.express as px
import torch
from huggingface_hub import hf_hub_download
# Load inference.py and model.
# inference.py is fetched from the model repo on the Hugging Face Hub and
# supplies predict_emotions, TOKENIZER, MODEL, DEVICE, EMOTION_LABELS and
# THRESHOLDS (all read from inference_module in this file).
repo_id = "logasanjeev/emotion-analyzer-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")
current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")
# Load the copy by explicit path instead of import_module("inference").
# A by-name import searches sys.path, whose first entry is the script's
# directory, not the working directory. That import therefore fails when the
# app is launched from another directory, and it could silently pick up a
# stale inference.py sitting next to the script instead of the file just
# downloaded.
inference_spec = importlib.util.spec_from_file_location("inference", destination)
inference_module = importlib.util.module_from_spec(inference_spec)
# Register before executing so `import inference` elsewhere resolves to this copy.
sys.modules["inference"] = inference_module
inference_spec.loader.exec_module(inference_module)
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")
# Warm-up call; its output is discarded. NOTE(review): presumably this triggers
# one-time setup inside inference.py so the first user request is fast, and
# the attribute reads below are kept after it in case that setup populates
# them. Confirm against inference.py.
_, _ = predict_emotions("dummy text")
emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
# Prediction function with grouped bar chart


def _parse_predictions(predictions_str):
    """Parse the ``"emotion: confidence"`` lines produced by ``predict_emotions``.

    Returns a list of ``(emotion, confidence)`` tuples in their original order.
    The sentinel ``"No emotions predicted."`` yields an empty list.
    """
    if predictions_str == "No emotions predicted.":
        return []
    parsed = []
    for raw_line in predictions_str.splitlines():
        line = raw_line.strip()
        if not line:
            # A trailing/blank line would otherwise break the unpacking below.
            continue
        # rsplit: only the last ": " separates the label from its score.
        emotion, confidence = line.rsplit(": ", 1)
        parsed.append((emotion, float(confidence)))
    return parsed


def _all_emotion_scores(processed_text):
    """Score every label for *processed_text* and return ``(label, score)`` pairs, highest first.

    ``predict_emotions`` only reports labels that pass their thresholds, so the
    model from ``inference_module`` is run again here to get a score for every
    label (needed for the Top 5 view). Scores are rounded to 4 decimals.
    """
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(inference_module.DEVICE)
    attention_mask = encodings["attention_mask"].to(inference_module.DEVICE)
    with torch.no_grad():
        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
    # Element-wise sigmoid: each label gets an independent score in [0, 1].
    # [0] selects the single input text of this batch.
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]
    scored = [
        (label, round(float(probability), 4))
        for label, probability in zip(emotion_labels, probabilities)
    ]
    # Stable sort: ties keep label order.
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored


def _build_comparison_chart(filtered_predictions, top_5_emotions):
    """Build a grouped bar chart comparing thresholded predictions with the Top 5 scores."""
    thresholded_dict = dict(filtered_predictions)
    top_5_dict = dict(top_5_emotions)
    # Deterministic x-axis order: the Top 5 (highest first), then any thresholded
    # emotion outside the Top 5. The previous set() made the bar order depend on
    # str hash randomization, so the chart layout changed from run to run.
    emotions = list(
        dict.fromkeys(
            [emotion for emotion, _ in top_5_emotions]
            + [emotion for emotion, _ in filtered_predictions]
        )
    )
    rows = []
    for emotion in emotions:
        if emotion in thresholded_dict:
            rows.append((emotion, thresholded_dict[emotion], "Above Threshold"))
        if emotion in top_5_dict:
            rows.append((emotion, top_5_dict[emotion], "Top 5"))
    df = pd.DataFrame(rows, columns=["Emotion", "Confidence", "Category"])
    fig = px.bar(
        df,
        x="Emotion",
        y="Confidence",
        color="Category",
        barmode="group",
        title="Emotion Confidence Comparison",
        height=400,
        color_discrete_map={"Above Threshold": "#6366f1", "Top 5": "#10b981"},
    )
    fig.update_traces(texttemplate="%{y:.2f}", textposition="auto")
    fig.update_layout(
        margin=dict(t=50, b=50),
        xaxis_title="",
        yaxis_title="Confidence",
        legend_title="",
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.5),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="#e5e7eb"),
    )
    return fig


def predict_emotions_with_details(text, confidence_threshold=0.0):
    """Analyse *text* and produce every output shown in the UI.

    Args:
        text: Raw user input. Empty, whitespace-only, or ``None`` input
            short-circuits with a prompt message.
        confidence_threshold: Extra floor on top of each label's default
            threshold. A prediction is kept only when its confidence is at
            least ``max(default_threshold, confidence_threshold)``. ``None``
            is treated as ``0.0``.

    Returns:
        A 4-tuple ``(processed_text, thresholded_output, top_5_output, fig)``:
        the text as returned by ``predict_emotions``; the kept predictions as
        ``"emotion: 0.1234"`` lines (or a "No emotions ..." message); the five
        highest-scoring labels in the same format; and a plotly figure, or
        ``None`` when there is nothing to plot.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None
    if confidence_threshold is None:
        confidence_threshold = 0.0

    predictions_str, processed_text = predict_emotions(text)
    predictions = _parse_predictions(predictions_str)

    # Top 5 across all labels, regardless of thresholds.
    top_5_emotions = _all_emotion_scores(processed_text)[:5]
    top_5_output = "\n".join(
        f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions
    )

    # Keep predictions that clear both the per-label default and the user floor.
    filtered_predictions = []
    for emotion, confidence in predictions:
        label_threshold = default_thresholds[emotion_labels.index(emotion)]
        if confidence >= max(label_threshold, confidence_threshold):
            filtered_predictions.append((emotion, confidence))

    if filtered_predictions:
        thresholded_output = "\n".join(
            f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
        )
    else:
        thresholded_output = "No emotions predicted above thresholds."

    fig = None
    if filtered_predictions or top_5_emotions:
        fig = _build_comparison_chart(filtered_predictions, top_5_emotions)
    return processed_text, thresholded_output, top_5_output, fig
# Enhanced CSS with modern design.
# Stylesheet passed to gr.Blocks(css=custom_css) further down.
# - #title, #description and #examples-title match the elem_id values (and the
#   inline <div id=...> markup) used by the header and examples Markdown components.
# - footer / footer a style the <footer> markup emitted via gr.HTML.
# - NOTE(review): .gr-panel, .gr-button, .gr-textbox, .gr-slider, .gr-plot and
#   .gr-examples are not set anywhere in this file; they rely on Gradio's own
#   generated markup. The custom elem_classes the components do set
#   ("input-textbox", "input-slider", "output-textbox", "output-plot") are not
#   targeted by any rule here. Confirm these selectors match the installed
#   Gradio version's DOM.
custom_css = """
body {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #111827 0%, #1f2937 100%);
color: #e5e7eb;
margin: 0;
padding: 24px;
min-height: 100vh;
display: flex;
flex-direction: column;
align-items: center;
}
.gr-panel {
border-radius: 16px;
box-shadow: 0 10px 30px rgba(0,0,0,0.3);
background: rgba(31, 41, 55, 0.9);
backdrop-filter: blur(12px);
padding: 32px;
margin: 24px auto;
max-width: 960px;
width: 100%;
border: 1px solid rgba(255, 255, 255, 0.1);
transition: transform 0.3s ease, box-shadow 0.3s ease;
}
.gr-panel:hover {
transform: translateY(-4px);
box-shadow: 0 12px 40px rgba(0,0,0,0.35);
}
.gr-button {
border-radius: 8px;
padding: 12px 32px;
font-weight: 600;
font-size: 16px;
background: linear-gradient(90deg, #6366f1 0%, #8b5cf6 100%);
color: #ffffff;
border: none;
transition: all 0.3s ease;
cursor: pointer;
margin-top: 16px;
}
.gr-button:hover {
background: linear-gradient(90deg, #8b5cf6 0%, #6366f1 100%);
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(99, 102, 241, 0.4);
}
.gr-button:focus {
outline: none;
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.3);
}
.gr-textbox, .gr-slider {
margin-bottom: 24px;
}
.gr-textbox label, .gr-slider label {
font-size: 16px;
font-weight: 600;
color: #e5e7eb;
margin-bottom: 8px;
display: block;
}
.gr-textbox textarea, .gr-textbox input {
border: 1px solid rgba(255, 255, 255, 0.15);
border-radius: 8px;
padding: 12px;
font-size: 16px;
background: rgba(55, 65, 81, 0.5);
color: #e5e7eb;
transition: border-color 0.3s ease, box-shadow 0.3s ease;
}
.gr-textbox textarea:focus, .gr-textbox input:focus {
border-color: #6366f1;
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
outline: none;
}
#title {
font-size: 2.5rem;
font-weight: 800;
color: #ffffff;
text-align: center;
margin: 32px 0 16px 0;
background: linear-gradient(90deg, #6366f1, #8b5cf6);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
#description {
font-size: 1.125rem;
color: #d1d5db;
text-align: center;
max-width: 720px;
margin: 0 auto 48px auto;
line-height: 1.75;
}
#examples-title {
font-size: 1.25rem;
font-weight: 600;
color: #e5e7eb;
margin: 32px 0 16px 0;
text-align: center;
}
footer {
text-align: center;
margin: 48px 0 24px 0;
padding: 16px;
font-size: 14px;
color: #d1d5db;
}
footer a {
color: #6366f1;
text-decoration: none;
font-weight: 500;
transition: color 0.3s ease;
}
footer a:hover {
color: #8b5cf6;
}
.gr-plot {
margin-top: 24px;
background: rgba(31, 41, 55, 0.5);
border-radius: 12px;
padding: 16px;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.gr-examples .example {
background: rgba(55, 65, 81, 0.7);
border-radius: 10px;
padding: 16px;
margin: 12px 0;
transition: all 0.3s ease;
cursor: pointer;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.gr-examples .example:hover {
background: rgba(99, 102, 241, 0.15);
transform: translateY(-2px);
border-color: #6366f1;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(16px); }
to { opacity: 1; transform: translateY(0); }
}
.gr-panel, #title, #description, footer, .gr-examples .example {
animation: fadeIn 0.6s ease-out;
}
/* Responsive design */
@media (max-width: 768px) {
.gr-panel {
padding: 24px;
margin: 16px;
}
#title {
font-size: 2rem;
}
#description {
font-size: 1rem;
}
.gr-button {
padding: 10px 24px;
font-size: 14px;
}
}
"""
# Gradio Blocks UI (Modernized)
# Layout: header, input controls, outputs (text + chart), clickable examples,
# and a footer. The submit button routes the text and slider value through
# predict_emotions_with_details into the four output components.
with gr.Blocks(css=custom_css) as demo:
    # Header
    gr.Markdown(
        "<div id='title'>Emotion Analyzer BERT</div>",
        elem_id="title"
    )
    # The string body stays unindented: in Markdown, 4+ leading spaces would
    # turn the HTML into a code block.
    gr.Markdown(
        """
<div id='description'>
Uncover the emotions in your text with our fine-tuned BERT model, trained on the GoEmotions dataset.
Enter your text, fine-tune the confidence threshold, and visualize the results in a sleek, interactive chart.
</div>
""",
        elem_id="description"
    )

    # Input Section
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Textbox(
                    label="Enter Your Text",
                    placeholder="Try: 'I'm over the moon today!' or 'This is so frustrating... 😣'",
                    lines=4,
                    show_label=True,
                    elem_classes=["input-textbox"]
                )
            with gr.Column(scale=1):
                # Extra confidence floor applied on top of the per-label defaults.
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.9,
                    value=0.0,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Filter emotions below this confidence level",
                    elem_classes=["input-slider"]
                )
        submit_btn = gr.Button("Analyze Emotions", variant="primary")

    # Output Section
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                processed_text_output = gr.Textbox(
                    label="Processed Text",
                    lines=2,
                    interactive=False,
                    elem_classes=["output-textbox"]
                )
                thresholded_output = gr.Textbox(
                    label="Detected Emotions (Above Threshold)",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"]
                )
                top_5_output = gr.Textbox(
                    label="Top 5 Emotions",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"]
                )
            with gr.Column(scale=2):
                output_plot = gr.Plot(
                    label="Emotion Confidence Visualization",
                    elem_classes=["output-plot"]
                )

    # Example carousel
    with gr.Group():
        gr.Markdown(
            "<div id='examples-title'>Try These Examples</div>",
            elem_id="examples-title"
        )
        # Each row holds exactly one value per component in `inputs`. The rows
        # previously carried a second "... Example" string with no matching
        # input component, which Gradio drops, so the category now lives in a
        # trailing comment. The emoji/apostrophes were also mis-decoded UTF-8
        # (mojibake) and have been restored.
        examples = gr.Examples(
            examples=[
                ["I’m thrilled to win this award! 😄"],  # Joy
                ["This is so frustrating, nothing works. 😣"],  # Annoyance
                ["I feel so sorry for what happened. 😒"],  # Sadness
                ["What a beautiful day to be alive! 🌞"],  # Admiration
                ["Feeling nervous about the exam tomorrow 😓 u/student r/study"],  # Nervousness
            ],
            inputs=[text_input],
            label=""
        )

    # Footer
    gr.HTML(
        """
<footer>
Created by logasanjeev |
<a href="https://huggingface.co/logasanjeev/emotion-analyzer-bert">Model Card</a> |
<a href="https://www.kaggle.com/code/ravindranlogasanjeev/evaluation-logasanjeev-emotion-analyzer-bert/notebook">Kaggle Notebook</a>
</footer>
"""
    )

    # Bind predictions
    submit_btn.click(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot]
    )

# Launch
if __name__ == "__main__":
    demo.launch()