logasanjeev commited on
Commit
0ce1d0f
·
verified ·
1 Parent(s): e4fe643

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -77
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import plotly.express as px
 
4
  import shutil
5
  import os
6
  import torch
7
  from huggingface_hub import hf_hub_download
8
  from importlib import import_module
9
 
 
10
  repo_id = "logasanjeev/goemotions-bert"
11
  local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
12
  print("Downloaded inference.py successfully!")
@@ -24,15 +26,21 @@ _, _ = predict_emotions("dummy text")
24
  emotion_labels = inference_module.EMOTION_LABELS
25
  default_thresholds = inference_module.THRESHOLDS
26
 
27
- def predict_emotions_with_details(text, confidence_threshold=0.0):
 
 
 
 
28
  predictions_str, processed_text = predict_emotions(text)
29
 
 
30
  predictions = []
31
  if predictions_str != "No emotions predicted.":
32
  for line in predictions_str.split("\n"):
33
  emotion, confidence = line.split(": ")
34
  predictions.append((emotion, float(confidence)))
35
 
 
36
  encodings = inference_module.TOKENIZER(
37
  processed_text,
38
  padding='max_length',
@@ -47,11 +55,13 @@ def predict_emotions_with_details(text, confidence_threshold=0.0):
47
  outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
48
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
49
 
 
50
  all_emotions = [(emotion_labels[i], round(logit, 4)) for i, logit in enumerate(logits)]
51
  all_emotions.sort(key=lambda x: x[1], reverse=True)
52
  top_5_emotions = all_emotions[:5]
53
  top_5_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions])
54
 
 
55
  filtered_predictions = []
56
  for emotion, confidence in predictions:
57
  thresh = default_thresholds[emotion_labels.index(emotion)]
@@ -64,67 +74,107 @@ def predict_emotions_with_details(text, confidence_threshold=0.0):
64
  else:
65
  thresholded_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions])
66
 
 
 
 
67
  if filtered_predictions:
68
  df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
69
- fig = px.bar(
70
- df,
71
- x="Emotion",
72
- y="Confidence",
73
- color="Emotion",
74
- text="Confidence",
75
- title="Emotion Confidence Levels (Above Threshold)",
76
- height=400
77
- )
78
- fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
79
- fig.update_layout(showlegend=False, margin=dict(t=40, b=40))
80
- else:
81
- fig = None
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- return processed_text, thresholded_output, top_5_output, fig
84
 
 
85
  custom_css = """
86
  body {
87
- font-family: 'Segoe UI', Arial, sans-serif;
 
88
  }
89
  .gr-panel {
90
- border-radius: 12px;
91
- box-shadow: 0 4px 12px rgba(0,0,0,0.1);
92
- background: linear-gradient(145deg, #ffffff, #f0f4f8);
 
 
93
  }
94
  .gr-button {
95
  border-radius: 8px;
96
- background: #007bff;
 
 
 
 
 
97
  color: white;
98
- padding: 10px 20px;
99
- transition: background 0.3s;
100
  }
101
- .gr-button:hover {
102
- background: #0056b3;
 
 
 
 
 
 
 
103
  }
104
  #title {
105
- font-size: 2.5em;
 
106
  color: #1a3c6e;
107
  text-align: center;
108
- margin-bottom: 20px;
109
  }
110
  #description {
111
- font-size: 1.1em;
112
- color: #333;
113
  text-align: center;
114
- max-width: 700px;
115
- margin: 0 auto;
116
  }
117
  #theme-toggle {
118
- position: absolute;
119
  top: 20px;
120
  right: 20px;
 
 
 
 
 
 
 
 
121
  }
122
  .dark-mode {
123
- background: #1a1a1a;
124
  color: #e0e0e0;
125
  }
126
  .dark-mode .gr-panel {
127
- background: linear-gradient(145deg, #2a2a2a, #3a3a3a);
 
128
  }
129
  .dark-mode #title {
130
  color: #66b3ff;
@@ -132,80 +182,176 @@ body {
132
  .dark-mode #description {
133
  color: #b0b0b0;
134
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  """
136
 
 
137
  theme_js = """
138
  function toggleTheme() {
139
  document.body.classList.toggle('dark-mode');
 
 
 
 
 
 
 
 
140
  }
141
  """
142
 
143
  # Gradio Blocks UI
144
  with gr.Blocks(css=custom_css) as demo:
 
 
 
 
 
 
 
 
145
  # Header
146
  gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>", elem_id="title")
147
  gr.Markdown(
148
  """
149
  <div id='description'>
150
- Predict emotions from text using a fine-tuned BERT-base model.
151
- Explore 28 emotions with optimized thresholds (Micro F1: 0.6006).
152
- View preprocessed text, top 5 emotions, and thresholded predictions!
153
  </div>
154
  """,
155
  elem_id="description"
156
  )
157
 
158
- # Theme toggle button
159
  with gr.Row():
160
- gr.HTML(
161
- """
162
- <button id='theme-toggle' onclick='toggleTheme()'>Toggle Dark Mode</button>
163
- <script>{}</script>
164
- """.format(theme_js)
165
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- # Main input and output
 
 
 
168
  with gr.Row():
169
  with gr.Column(scale=1):
170
- text_input = gr.Textbox(
171
- label="Enter Your Text",
172
- placeholder="Type something like 'I’m just chilling today'...",
173
- lines=3
174
- )
175
- confidence_slider = gr.Slider(
176
- minimum=0.0,
177
- maximum=0.9,
178
- value=0.0,
179
- step=0.05,
180
- label="Minimum Confidence Threshold",
181
- info="Adjust to filter low-confidence predictions"
182
- )
183
- submit_btn = gr.Button("Predict Emotions", variant="primary")
184
-
185
- with gr.Column(scale=1):
186
- processed_text_output = gr.Textbox(label="Preprocessed Text", lines=2)
187
- thresholded_output = gr.Textbox(label="Predicted Emotions (Above Threshold)", lines=5)
188
- top_5_output = gr.Textbox(label="Top 5 Emotions (Regardless of Threshold)", lines=5)
189
- output_plot = gr.Plot(label="Emotion Confidence Chart (Above Threshold)")
190
 
191
  # Example carousel
192
- examples = gr.Examples(
193
- examples=[
194
- "I’m just chilling today.",
195
- "Thank you for saving my life!",
196
- "I’m nervous about my exam tomorrow.",
197
- "I love my new puppy so much!",
198
- "I’m so relieved the storm passed."
199
- ],
200
- inputs=text_input,
201
- label="Try These Examples"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  )
203
 
204
- # Bind prediction
205
  submit_btn.click(
206
  fn=predict_emotions_with_details,
207
- inputs=[text_input, confidence_slider],
208
- outputs=[processed_text_output, thresholded_output, top_5_output, output_plot]
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  )
210
 
211
  # Launch
 
1
  import gradio as gr
2
  import pandas as pd
3
  import plotly.express as px
4
+ import plotly.graph_objects as go
5
  import shutil
6
  import os
7
  import torch
8
  from huggingface_hub import hf_hub_download
9
  from importlib import import_module
10
 
11
+ # Load inference.py and model
12
  repo_id = "logasanjeev/goemotions-bert"
13
  local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
14
  print("Downloaded inference.py successfully!")
 
26
  emotion_labels = inference_module.EMOTION_LABELS
27
  default_thresholds = inference_module.THRESHOLDS
28
 
29
+ # Prediction function with export capability
30
+ def predict_emotions_with_details(text, confidence_threshold=0.0, chart_type="bar"):
31
+ if not text.strip():
32
+ return "Please enter some text.", "", "", None, None
33
+
34
  predictions_str, processed_text = predict_emotions(text)
35
 
36
+ # Parse predictions
37
  predictions = []
38
  if predictions_str != "No emotions predicted.":
39
  for line in predictions_str.split("\n"):
40
  emotion, confidence = line.split(": ")
41
  predictions.append((emotion, float(confidence)))
42
 
43
+ # Get raw logits for all emotions
44
  encodings = inference_module.TOKENIZER(
45
  processed_text,
46
  padding='max_length',
 
55
  outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
56
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
57
 
58
+ # All emotions for top 5
59
  all_emotions = [(emotion_labels[i], round(logit, 4)) for i, logit in enumerate(logits)]
60
  all_emotions.sort(key=lambda x: x[1], reverse=True)
61
  top_5_emotions = all_emotions[:5]
62
  top_5_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions])
63
 
64
+ # Filter predictions based on threshold
65
  filtered_predictions = []
66
  for emotion, confidence in predictions:
67
  thresh = default_thresholds[emotion_labels.index(emotion)]
 
74
  else:
75
  thresholded_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions])
76
 
77
+ # Create visualization
78
+ fig = None
79
+ df_export = None
80
  if filtered_predictions:
81
  df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
82
+ df_export = df.copy()
83
+
84
+ if chart_type == "bar":
85
+ fig = px.bar(
86
+ df,
87
+ x="Emotion",
88
+ y="Confidence",
89
+ color="Emotion",
90
+ text="Confidence",
91
+ title="Emotion Confidence Levels (Above Threshold)",
92
+ height=400,
93
+ color_discrete_sequence=px.colors.qualitative.Plotly
94
+ )
95
+ fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
96
+ fig.update_layout(showlegend=False, margin=dict(t=40, b=40), xaxis_title="", yaxis_title="Confidence")
97
+ else: # pie chart
98
+ fig = px.pie(
99
+ df,
100
+ names="Emotion",
101
+ values="Confidence",
102
+ title="Emotion Confidence Distribution (Above Threshold)",
103
+ height=400,
104
+ color_discrete_sequence=px.colors.qualitative.Plotly
105
+ )
106
+ fig.update_traces(textinfo='percent+label', pull=[0.1] + [0] * (len(df) - 1))
107
+ fig.update_layout(margin=dict(t=40, b=40))
108
 
109
+ return processed_text, thresholded_output, top_5_output, fig, df_export
110
 
111
+ # Custom CSS for enhanced styling
112
  custom_css = """
113
  body {
114
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
115
+ background-color: #f5f7fa;
116
  }
117
  .gr-panel {
118
+ border-radius: 16px;
119
+ box-shadow: 0 6px 20px rgba(0,0,0,0.08);
120
+ background: white;
121
+ padding: 20px;
122
+ margin-bottom: 20px;
123
  }
124
  .gr-button {
125
  border-radius: 8px;
126
+ padding: 12px 24px;
127
+ font-weight: 600;
128
+ transition: all 0.3s ease;
129
+ }
130
+ .gr-button-primary {
131
+ background: #4a90e2;
132
  color: white;
 
 
133
  }
134
+ .gr-button-primary:hover {
135
+ background: #357abd;
136
+ }
137
+ .gr-button-secondary {
138
+ background: #e4e7eb;
139
+ color: #333;
140
+ }
141
+ .gr-button-secondary:hover {
142
+ background: #d1d5db;
143
  }
144
  #title {
145
+ font-size: 2.8em;
146
+ font-weight: 700;
147
  color: #1a3c6e;
148
  text-align: center;
149
+ margin-bottom: 10px;
150
  }
151
  #description {
152
+ font-size: 1.2em;
153
+ color: #555;
154
  text-align: center;
155
+ max-width: 800px;
156
+ margin: 0 auto 30px auto;
157
  }
158
  #theme-toggle {
159
+ position: fixed;
160
  top: 20px;
161
  right: 20px;
162
+ background: none;
163
+ border: none;
164
+ font-size: 1.5em;
165
+ cursor: pointer;
166
+ transition: transform 0.3s;
167
+ }
168
+ #theme-toggle:hover {
169
+ transform: scale(1.2);
170
  }
171
  .dark-mode {
172
+ background: #1e2a44;
173
  color: #e0e0e0;
174
  }
175
  .dark-mode .gr-panel {
176
+ background: #2a3a5a;
177
+ box-shadow: 0 6px 20px rgba(0,0,0,0.2);
178
  }
179
  .dark-mode #title {
180
  color: #66b3ff;
 
182
  .dark-mode #description {
183
  color: #b0b0b0;
184
  }
185
+ .dark-mode .gr-button-secondary {
186
+ background: #3a4a6a;
187
+ color: #e0e0e0;
188
+ }
189
+ .dark-mode .gr-button-secondary:hover {
190
+ background: #4a5a7a;
191
+ }
192
+ #loading {
193
+ font-style: italic;
194
+ color: #888;
195
+ text-align: center;
196
+ }
197
+ #examples-title {
198
+ font-size: 1.5em;
199
+ font-weight: 600;
200
+ color: #1a3c6e;
201
+ margin-bottom: 10px;
202
+ }
203
+ .dark-mode #examples-title {
204
+ color: #66b3ff;
205
+ }
206
+ footer {
207
+ text-align: center;
208
+ margin-top: 40px;
209
+ padding: 20px;
210
+ font-size: 0.9em;
211
+ color: #666;
212
+ }
213
+ footer a {
214
+ color: #4a90e2;
215
+ text-decoration: none;
216
+ }
217
+ footer a:hover {
218
+ text-decoration: underline;
219
+ }
220
+ .dark-mode footer {
221
+ color: #b0b0b0;
222
+ }
223
  """
224
 
225
+ # JavaScript for theme toggle and loading spinner
226
  theme_js = """
227
  function toggleTheme() {
228
  document.body.classList.toggle('dark-mode');
229
+ const toggleBtn = document.getElementById('theme-toggle');
230
+ toggleBtn.innerHTML = document.body.classList.contains('dark-mode') ? '☀️' : '🌙';
231
+ }
232
+ function showLoading() {
233
+ document.getElementById('loading').style.display = 'block';
234
+ }
235
+ function hideLoading() {
236
+ document.getElementById('loading').style.display = 'none';
237
  }
238
  """
239
 
240
  # Gradio Blocks UI
241
  with gr.Blocks(css=custom_css) as demo:
242
+ # Theme toggle button
243
+ gr.HTML(
244
+ """
245
+ <button id='theme-toggle' onclick='toggleTheme()'>🌙</button>
246
+ <script>{}</script>
247
+ """.format(theme_js)
248
+ )
249
+
250
  # Header
251
  gr.Markdown("<div id='title'>GoEmotions BERT Classifier</div>", elem_id="title")
252
  gr.Markdown(
253
  """
254
  <div id='description'>
255
+ Predict emotions from text using a fine-tuned BERT-base model on the GoEmotions dataset.
256
+ Detect 28 emotions with optimized thresholds (Micro F1: 0.6006).
257
+ View preprocessed text, top 5 emotions, and thresholded predictions with interactive visualizations!
258
  </div>
259
  """,
260
  elem_id="description"
261
  )
262
 
263
+ # Main content
264
  with gr.Row():
265
+ with gr.Column(scale=1):
266
+ # Input Section
267
+ with gr.Group():
268
+ gr.Markdown("### Input Text")
269
+ text_input = gr.Textbox(
270
+ label="Enter Your Text",
271
+ placeholder="Type something like 'I’m just chilling today'...",
272
+ lines=3,
273
+ show_label=False
274
+ )
275
+ confidence_slider = gr.Slider(
276
+ minimum=0.0,
277
+ maximum=0.9,
278
+ value=0.0,
279
+ step=0.05,
280
+ label="Minimum Confidence Threshold",
281
+ info="Filter predictions below this confidence level (default thresholds still apply)"
282
+ )
283
+ chart_type = gr.Radio(
284
+ choices=["bar", "pie"],
285
+ value="bar",
286
+ label="Chart Type",
287
+ info="Choose how to visualize the emotion confidences"
288
+ )
289
+ with gr.Row():
290
+ submit_btn = gr.Button("Predict Emotions", variant="primary")
291
+ reset_btn = gr.Button("Reset", variant="secondary")
292
 
293
+ # Loading indicator
294
+ gr.HTML("<div id='loading' style='display:none;'>Predicting emotions, please wait...</div>")
295
+
296
+ # Output Section
297
  with gr.Row():
298
  with gr.Column(scale=1):
299
+ with gr.Group():
300
+ gr.Markdown("### Results")
301
+ processed_text_output = gr.Textbox(label="Preprocessed Text", lines=2, interactive=False)
302
+ thresholded_output = gr.Textbox(label="Predicted Emotions (Above Threshold)", lines=5, interactive=False)
303
+ top_5_output = gr.Textbox(label="Top 5 Emotions (Regardless of Threshold)", lines=5, interactive=False)
304
+ output_plot = gr.Plot(label="Emotion Confidence Visualization (Above Threshold)")
305
+
306
+ # Export predictions
307
+ export_btn = gr.File(label="Download Predictions as CSV", visible=False)
 
 
 
 
 
 
 
 
 
 
 
308
 
309
  # Example carousel
310
+ with gr.Group():
311
+ gr.Markdown("<div id='examples-title'>Example Texts</div>", elem_id="examples-title")
312
+ examples = gr.Examples(
313
+ examples=[
314
+ ["I’m just chilling today.", "Neutral Example"],
315
+ ["Thank you for saving my life!", "Gratitude Example"],
316
+ ["I’m nervous about my exam tomorrow.", "Nervousness Example"],
317
+ ["I love my new puppy so much!", "Love Example"],
318
+ ["I’m so relieved the storm passed.", "Relief Example"]
319
+ ],
320
+ inputs=[text_input],
321
+ label="",
322
+ examples_per_page=3
323
+ )
324
+
325
+ # Footer
326
+ gr.HTML(
327
+ """
328
+ <footer>
329
+ Built with ❤️ by logasanjeev |
330
+ <a href="https://huggingface.co/logasanjeev/goemotions-bert">Model Card</a> |
331
+ <a href="https://www.kaggle.com/code/ravindranlogasanjeev/evaluation-logasanjeev-goemotions-bert/notebook">Kaggle Notebook</a> |
332
+ <a href="https://github.com/logasanjeev">GitHub</a>
333
+ </footer>
334
+ """
335
  )
336
 
337
+ # Bind predictions with loading spinner
338
  submit_btn.click(
339
  fn=predict_emotions_with_details,
340
+ inputs=[text_input, confidence_slider, chart_type],
341
+ outputs=[processed_text_output, thresholded_output, top_5_output, output_plot, export_btn],
342
+ _js="showLoading(); return [arguments[0], arguments[1], arguments[2]]"
343
+ ).then(
344
+ fn=None,
345
+ inputs=None,
346
+ outputs=None,
347
+ _js="hideLoading"
348
+ )
349
+
350
+ # Reset functionality
351
+ reset_btn.click(
352
+ fn=lambda: ("", "", "", None, None),
353
+ inputs=[],
354
+ outputs=[text_input, processed_text_output, thresholded_output, top_5_output, output_plot, export_btn]
355
  )
356
 
357
  # Launch