Update app.py
app.py
CHANGED
@@ -1,76 +1,86 @@
 import gradio as gr
-import torch
-import numpy as np
-from transformers import BertForSequenceClassification, BertTokenizer
-import requests
-import json
-import plotly.express as px
 import pandas as pd
+import plotly.express as px
+import shutil
+import os
+from huggingface_hub import hf_hub_download
+from importlib import import_module

-# Load model and tokenizer from Hugging Face Hub
 repo_id = "logasanjeev/goemotions-bert"
+local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
+print("Downloaded inference.py successfully!")
+
+current_dir = os.getcwd()
+destination = os.path.join(current_dir, "inference.py")
+shutil.copy(local_file, destination)
+print("Copied inference.py to current directory!")

-thresholds_data = json.loads(response.text)
-emotion_labels = thresholds_data["emotion_labels"]
-default_thresholds = thresholds_data["thresholds"]
+inference_module = import_module("inference")
+predict_emotions = inference_module.predict_emotions
+print("Imported predict_emotions successfully!")

+_, _ = predict_emotions("dummy text")
+emotion_labels = inference_module.EMOTION_LABELS
+default_thresholds = inference_module.THRESHOLDS
+
+def predict_emotions_with_details(text, confidence_threshold=0.0):
+    predictions_str, processed_text = predict_emotions(text)
+
+    predictions = []
+    if predictions_str != "No emotions predicted.":
+        for line in predictions_str.split("\n"):
+            emotion, confidence = line.split(": ")
+            predictions.append((emotion, float(confidence)))
+
+    encodings = inference_module.TOKENIZER(
+        processed_text,
         padding='max_length',
         truncation=True,
         max_length=128,
         return_tensors='pt'
     )
-    input_ids = encodings['input_ids'].to(
-    attention_mask = encodings['attention_mask'].to(
+    input_ids = encodings['input_ids'].to(inference_module.DEVICE)
+    attention_mask = encodings['attention_mask'].to(inference_module.DEVICE)

     with torch.no_grad():
-        outputs =
+        outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
     logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]

-        if logit >= adjusted_thresh:
-            predictions.append((emotion_labels[i], logit))
+    all_emotions = [(emotion_labels[i], round(logit, 4)) for i, logit in enumerate(logits)]
+    all_emotions.sort(key=lambda x: x[1], reverse=True)
+    top_5_emotions = all_emotions[:5]
+    top_5_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions])

+    filtered_predictions = []
+    for emotion, confidence in predictions:
+        thresh = default_thresholds[emotion_labels.index(emotion)]
+        adjusted_thresh = max(thresh, confidence_threshold)
+        if confidence >= adjusted_thresh:
+            filtered_predictions.append((emotion, confidence))

+    if not filtered_predictions:
+        thresholded_output = "No emotions predicted above thresholds."
+    else:
+        thresholded_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions])

+    if filtered_predictions:
+        df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
+        fig = px.bar(
+            df,
+            x="Emotion",
+            y="Confidence",
+            color="Emotion",
+            text="Confidence",
+            title="Emotion Confidence Levels (Above Threshold)",
+            height=400
+        )
+        fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
+        fig.update_layout(showlegend=False, margin=dict(t=40, b=40))
+    else:
+        fig = None

-    return
+    return processed_text, thresholded_output, top_5_output, fig

-# Custom CSS for modern UI
 custom_css = """
 body {
     font-family: 'Segoe UI', Arial, sans-serif;
@@ -123,7 +133,6 @@ body {
 }
 """

-# JavaScript for theme toggle
 theme_js = """
 function toggleTheme() {
     document.body.classList.toggle('dark-mode');
@@ -138,8 +147,8 @@ with gr.Blocks(css=custom_css) as demo:
         """
         <div id='description'>
         Predict emotions from text using a fine-tuned BERT-base model.
-        Explore 28 emotions with optimized thresholds (Micro F1: 0.
+        Explore 28 emotions with optimized thresholds (Micro F1: 0.6006).
+        View preprocessed text, top 5 emotions, and thresholded predictions!
         </div>
         """,
         elem_id="description"
@@ -173,8 +182,10 @@ with gr.Blocks(css=custom_css) as demo:
             submit_btn = gr.Button("Predict Emotions", variant="primary")

         with gr.Column(scale=1):
+            processed_text_output = gr.Textbox(label="Preprocessed Text", lines=2)
+            thresholded_output = gr.Textbox(label="Predicted Emotions (Above Threshold)", lines=5)
+            top_5_output = gr.Textbox(label="Top 5 Emotions (Regardless of Threshold)", lines=5)
+            output_plot = gr.Plot(label="Emotion Confidence Chart (Above Threshold)")

     # Example carousel
     examples = gr.Examples(
@@ -191,9 +202,9 @@ with gr.Blocks(css=custom_css) as demo:

     # Bind prediction
     submit_btn.click(
-        fn=
+        fn=predict_emotions_with_details,
         inputs=[text_input, confidence_slider],
-        outputs=[
+        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot]
     )

     # Launch
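
For orientation, a minimal sketch (illustrative only, not part of this commit) of the bootstrap the updated app.py performs at startup: download inference.py from the logasanjeev/goemotions-bert repo, import it dynamically, and call its predict_emotions helper directly. It assumes the Space's dependencies are installed (huggingface_hub plus whatever inference.py itself needs, e.g. torch and transformers).

# Illustrative sketch only; mirrors the hub-download + dynamic-import logic in the diff above.
import os
import shutil
from importlib import import_module

from huggingface_hub import hf_hub_download

repo_id = "logasanjeev/goemotions-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")  # fetch the helper script
shutil.copy(local_file, os.path.join(os.getcwd(), "inference.py"))      # make it importable

inference = import_module("inference")
predictions_str, processed_text = inference.predict_emotions("I love this!")
print(processed_text)   # text after the module's own preprocessing
print(predictions_str)  # "emotion: confidence" lines, or "No emotions predicted."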