|
import os
import shutil
import sys
from importlib import import_module, invalidate_caches

import gradio as gr
import pandas as pd
import plotly.express as px
import torch
from huggingface_hub import hf_hub_download
|
|
|
|
|
# Fetch the model's companion inference script from the Hugging Face Hub,
# place it in the working directory and import it as the module ``inference``.
# It provides predict_emotions plus the TOKENIZER / MODEL / DEVICE /
# EMOTION_LABELS / THRESHOLDS attributes used elsewhere in this file.
repo_id = "logasanjeev/emotion-analyzer-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")

current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")

# The working directory is not guaranteed to be on sys.path: ``python
# path/to/app.py`` adds the script's directory, not the cwd. Put it first so
# the freshly copied file is found, and takes precedence over any installed
# package that happens to be called ``inference``.
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# The file was created after interpreter start-up. Drop the finders' cached
# directory listings so the import system is guaranteed to notice it.
invalidate_caches()
inference_module = import_module("inference")
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")

# Warm-up call whose result is discarded.
# NOTE(review): presumably this triggers any lazy initialisation in
# inference.py before its attributes are read below — confirm.
_, _ = predict_emotions("dummy text")
emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
|
|
|
|
|
def _parse_predictions(predictions_str):
    """Turn the text returned by ``predict_emotions`` into ``(label, score)`` tuples.

    The text is expected to hold one ``"label: score"`` pair per line, or the
    sentinel ``"No emotions predicted."``. The sentinel, empty output and
    blank lines all contribute no entries.
    """
    if predictions_str.strip() == "No emotions predicted.":
        return []
    parsed = []
    for line in predictions_str.splitlines():
        if not line.strip():
            continue
        # Split on the last ": " so the score is always the final field.
        label, score = line.rsplit(": ", 1)
        parsed.append((label.strip(), float(score)))
    return parsed


def _score_all_emotions(processed_text):
    """Return ``(label, probability)`` for every label, most likely first.

    Probabilities are the per-label sigmoid of the model's logits, rounded to
    four decimals.
    """
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    input_ids = encodings["input_ids"].to(inference_module.DEVICE)
    attention_mask = encodings["attention_mask"].to(inference_module.DEVICE)

    with torch.no_grad():
        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
    # Each label is scored independently with a sigmoid. A single text was
    # tokenized, so row 0 is the only row in the batch.
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    scored = [
        (emotion_labels[i], round(float(prob), 4))
        for i, prob in enumerate(probabilities)
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored


def _format_scores(pairs):
    """Render ``(label, score)`` pairs as newline-separated ``"label: 0.1234"`` lines."""
    return "\n".join(f"{label}: {score:.4f}" for label, score in pairs)


def _build_comparison_figure(filtered_predictions, top_5_emotions):
    """Build a grouped bar chart of above-threshold emotions vs. the top 5.

    Emotions are plotted in a stable order: the top 5 (already sorted by
    descending probability), then any remaining above-threshold emotions.
    Iterating a ``set`` of strings would make the order change between runs.
    Returns ``None`` when both inputs are empty.
    """
    if not filtered_predictions and not top_5_emotions:
        return None

    series = (
        ("Above Threshold", dict(filtered_predictions)),
        ("Top 5", dict(top_5_emotions)),
    )
    # dict.fromkeys removes duplicates while keeping first-seen order.
    ordered_emotions = list(dict.fromkeys(
        [label for label, _ in top_5_emotions]
        + [label for label, _ in filtered_predictions]
    ))

    data = {"Emotion": [], "Confidence": [], "Category": []}
    for emotion in ordered_emotions:
        for category, scores in series:
            if emotion in scores:
                data["Emotion"].append(emotion)
                data["Confidence"].append(scores[emotion])
                data["Category"].append(category)

    fig = px.bar(
        pd.DataFrame(data),
        x="Emotion",
        y="Confidence",
        color="Category",
        barmode="group",
        title="Emotion Confidence Comparison",
        height=400,
        color_discrete_map={"Above Threshold": "#6366f1", "Top 5": "#10b981"},
        # Pin the x-axis to the order computed above.
        category_orders={"Emotion": ordered_emotions},
    )
    fig.update_traces(texttemplate="%{y:.2f}", textposition="auto")
    fig.update_layout(
        margin=dict(t=50, b=50),
        xaxis_title="",
        yaxis_title="Confidence",
        legend_title="",
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.5),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="#e5e7eb"),
    )
    return fig


def predict_emotions_with_details(text, confidence_threshold=0.0):
    """Analyse *text* and produce the four outputs shown by the Gradio UI.

    Args:
        text: Raw user input. ``None``, empty and whitespace-only input are
            rejected with a prompt message.
        confidence_threshold: Extra minimum confidence from the UI slider. A
            prediction is kept only if it reaches both this value and its
            label's default threshold, so the slider can only tighten filtering.

    Returns:
        A 4-tuple ``(processed_text, thresholded_output, top_5_output, fig)``:
        - the text as preprocessed by ``predict_emotions``;
        - the surviving predictions as ``"label: 0.1234"`` lines, or a notice
          when none survive;
        - the five highest-probability labels in the same format;
        - a Plotly bar chart, or ``None``.
        For rejected input it is
        ``("Please enter some text.", "", "", None)``.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None

    predictions_str, processed_text = predict_emotions(text)
    predictions = _parse_predictions(predictions_str)

    # NOTE(review): predict_emotions only exposes its predictions as text, so
    # the full per-label distribution needs a second forward pass over the
    # same input. Worth merging if inference.py can return raw probabilities.
    top_5_emotions = _score_all_emotions(processed_text)[:5]
    top_5_output = _format_scores(top_5_emotions)

    filtered_predictions = []
    for emotion, confidence in predictions:
        label_threshold = default_thresholds[emotion_labels.index(emotion)]
        if confidence >= max(label_threshold, confidence_threshold):
            filtered_predictions.append((emotion, confidence))

    if filtered_predictions:
        thresholded_output = _format_scores(filtered_predictions)
    else:
        thresholded_output = "No emotions predicted above thresholds."

    fig = _build_comparison_figure(filtered_predictions, top_5_emotions)
    return processed_text, thresholded_output, top_5_output, fig
|
|
|
|
|
# Stylesheet passed to ``gr.Blocks(css=...)``. It defines:
# - a dark gradient page background and translucent hover-lifting panels;
# - gradient buttons and styled text/slider inputs;
# - gradient-clipped title text;
# - footer, plot and example-card styling;
# - a fade-in entrance animation and a <=768px responsive breakpoint.
# NOTE(review): the ``.gr-*`` selectors presumably target Gradio's own class
# names. The components in this file are tagged with elem_classes such as
# "input-textbox" / "output-plot", which no rule here references. Confirm
# these rules still match the DOM of the installed Gradio version.
custom_css = """
body {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    background: linear-gradient(135deg, #111827 0%, #1f2937 100%);
    color: #e5e7eb;
    margin: 0;
    padding: 24px;
    min-height: 100vh;
    display: flex;
    flex-direction: column;
    align-items: center;
}

.gr-panel {
    border-radius: 16px;
    box-shadow: 0 10px 30px rgba(0,0,0,0.3);
    background: rgba(31, 41, 55, 0.9);
    backdrop-filter: blur(12px);
    padding: 32px;
    margin: 24px auto;
    max-width: 960px;
    width: 100%;
    border: 1px solid rgba(255, 255, 255, 0.1);
    transition: transform 0.3s ease, box-shadow 0.3s ease;
}

.gr-panel:hover {
    transform: translateY(-4px);
    box-shadow: 0 12px 40px rgba(0,0,0,0.35);
}

.gr-button {
    border-radius: 8px;
    padding: 12px 32px;
    font-weight: 600;
    font-size: 16px;
    background: linear-gradient(90deg, #6366f1 0%, #8b5cf6 100%);
    color: #ffffff;
    border: none;
    transition: all 0.3s ease;
    cursor: pointer;
    margin-top: 16px;
}

.gr-button:hover {
    background: linear-gradient(90deg, #8b5cf6 0%, #6366f1 100%);
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(99, 102, 241, 0.4);
}

.gr-button:focus {
    outline: none;
    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.3);
}

.gr-textbox, .gr-slider {
    margin-bottom: 24px;
}

.gr-textbox label, .gr-slider label {
    font-size: 16px;
    font-weight: 600;
    color: #e5e7eb;
    margin-bottom: 8px;
    display: block;
}

.gr-textbox textarea, .gr-textbox input {
    border: 1px solid rgba(255, 255, 255, 0.15);
    border-radius: 8px;
    padding: 12px;
    font-size: 16px;
    background: rgba(55, 65, 81, 0.5);
    color: #e5e7eb;
    transition: border-color 0.3s ease, box-shadow 0.3s ease;
}

.gr-textbox textarea:focus, .gr-textbox input:focus {
    border-color: #6366f1;
    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
    outline: none;
}

#title {
    font-size: 2.5rem;
    font-weight: 800;
    color: #ffffff;
    text-align: center;
    margin: 32px 0 16px 0;
    background: linear-gradient(90deg, #6366f1, #8b5cf6);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}

#description {
    font-size: 1.125rem;
    color: #d1d5db;
    text-align: center;
    max-width: 720px;
    margin: 0 auto 48px auto;
    line-height: 1.75;
}

#examples-title {
    font-size: 1.25rem;
    font-weight: 600;
    color: #e5e7eb;
    margin: 32px 0 16px 0;
    text-align: center;
}

footer {
    text-align: center;
    margin: 48px 0 24px 0;
    padding: 16px;
    font-size: 14px;
    color: #d1d5db;
}

footer a {
    color: #6366f1;
    text-decoration: none;
    font-weight: 500;
    transition: color 0.3s ease;
}

footer a:hover {
    color: #8b5cf6;
}

.gr-plot {
    margin-top: 24px;
    background: rgba(31, 41, 55, 0.5);
    border-radius: 12px;
    padding: 16px;
    border: 1px solid rgba(255, 255, 255, 0.1);
}

.gr-examples .example {
    background: rgba(55, 65, 81, 0.7);
    border-radius: 10px;
    padding: 16px;
    margin: 12px 0;
    transition: all 0.3s ease;
    cursor: pointer;
    border: 1px solid rgba(255, 255, 255, 0.1);
}

.gr-examples .example:hover {
    background: rgba(99, 102, 241, 0.15);
    transform: translateY(-2px);
    border-color: #6366f1;
}

@keyframes fadeIn {
    from { opacity: 0; transform: translateY(16px); }
    to { opacity: 1; transform: translateY(0); }
}

.gr-panel, #title, #description, footer, .gr-examples .example {
    animation: fadeIn 0.6s ease-out;
}

/* Responsive design */
@media (max-width: 768px) {
    .gr-panel {
        padding: 24px;
        margin: 16px;
    }
    #title {
        font-size: 2rem;
    }
    #description {
        font-size: 1rem;
    }
    .gr-button {
        padding: 10px 24px;
        font-size: 14px;
    }
}
"""
|
|
|
|
|
# Build the web interface. Components appear on the page in the order they are
# created inside this ``with gr.Blocks(...)`` block.
with gr.Blocks(css=custom_css) as demo:
    # Header.
    # NOTE(review): ``elem_id`` sets the id of Gradio's wrapper element, and
    # the HTML repeats the same id on the inner <div>. The ``#title`` /
    # ``#description`` CSS rules may therefore apply twice. Confirm the
    # rendered spacing is intended.
    gr.Markdown(
        "<div id='title'>Emotion Analyzer BERT</div>",
        elem_id="title",
    )
    gr.Markdown(
        """
        <div id='description'>
        Uncover the emotions in your text with our fine-tuned BERT model, trained on the GoEmotions dataset.
        Enter your text, fine-tune the confidence threshold, and visualize the results in a sleek, interactive chart.
        </div>
        """,
        elem_id="description",
    )

    # Input panel: free text on the left; threshold slider and submit button
    # on the right.
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Textbox(
                    label="Enter Your Text",
                    placeholder="Try: 'I'm over the moon today!' or 'This is so frustrating... 😣'",
                    lines=4,
                    show_label=True,
                    elem_classes=["input-textbox"],
                )
            with gr.Column(scale=1):
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.9,
                    value=0.0,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Filter emotions below this confidence level",
                    elem_classes=["input-slider"],
                )
                submit_btn = gr.Button("Analyze Emotions", variant="primary")

    # Results panel: text outputs on the left, comparison chart on the right.
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                processed_text_output = gr.Textbox(
                    label="Processed Text",
                    lines=2,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                thresholded_output = gr.Textbox(
                    label="Detected Emotions (Above Threshold)",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                top_5_output = gr.Textbox(
                    label="Top 5 Emotions",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
            with gr.Column(scale=2):
                output_plot = gr.Plot(
                    label="Emotion Confidence Visualization",
                    elem_classes=["output-plot"],
                )

    # Sample inputs; selecting one fills the text box.
    with gr.Group():
        gr.Markdown(
            "<div id='examples-title'>Try These Examples</div>",
            elem_id="examples-title",
        )
        # One value per row, matching the single input component. Extra
        # columns would be ignored, so each sample's intended emotion is kept
        # as a comment instead.
        # NOTE(review): the emoji in the 1st, 4th and 5th samples were
        # mis-decoded in the reviewed copy. Only their leading UTF-8 bytes
        # survived, so 🎉 / 🌞 / 😓 are best guesses consistent with those
        # bytes — confirm.
        examples = gr.Examples(
            examples=[
                ["I’m thrilled to win this award! 🎉"],  # Joy Example
                ["This is so frustrating, nothing works. 😣"],  # Annoyance Example
                ["I feel so sorry for what happened. 😢"],  # Sadness Example
                ["What a beautiful day to be alive! 🌞"],  # Admiration Example
                ["Feeling nervous about the exam tomorrow 😓 u/student r/study"],  # Nervousness Example
            ],
            inputs=[text_input],
            label="",
        )

    gr.HTML(
        """
        <footer>
        Created by logasanjeev |
        <a href="https://huggingface.co/logasanjeev/emotion-analyzer-bert">Model Card</a> |
        <a href="https://www.kaggle.com/code/ravindranlogasanjeev/evaluation-logasanjeev-emotion-analyzer-bert/notebook">Kaggle Notebook</a>
        </footer>
        """
    )

    # The outputs list order must match the 4-tuple returned by
    # predict_emotions_with_details.
    submit_btn.click(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot],
    )


if __name__ == "__main__":
    demo.launch()