logasanjeev commited on
Commit
d581a83
·
verified ·
1 Parent(s): aaf3cb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -60
app.py CHANGED
@@ -1,76 +1,86 @@
1
  import gradio as gr
2
- import torch
3
- import numpy as np
4
- from transformers import BertForSequenceClassification, BertTokenizer
5
- import requests
6
- import json
7
- import plotly.express as px
8
  import pandas as pd
 
 
 
 
 
9
 
10
- # Load model and tokenizer from Hugging Face Hub
11
  repo_id = "logasanjeev/goemotions-bert"
12
- model = BertForSequenceClassification.from_pretrained(repo_id)
13
- tokenizer = BertTokenizer.from_pretrained(repo_id)
14
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
- model.to(device)
16
- if torch.cuda.device_count() > 1:
17
- model = nn.DataParallel(model)
18
- model.eval()
19
 
20
- # Load optimized thresholds from Hugging Face Hub
21
- thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
22
- response = requests.get(thresholds_url)
23
- thresholds_data = json.loads(response.text)
24
- emotion_labels = thresholds_data["emotion_labels"]
25
- default_thresholds = thresholds_data["thresholds"]
26
 
27
- # Prediction function
28
- def predict_emotions(text, confidence_threshold=0.0):
29
- encodings = tokenizer(
30
- text,
 
 
 
 
 
 
 
 
 
 
 
31
  padding='max_length',
32
  truncation=True,
33
  max_length=128,
34
  return_tensors='pt'
35
  )
36
- input_ids = encodings['input_ids'].to(device)
37
- attention_mask = encodings['attention_mask'].to(device)
38
 
39
  with torch.no_grad():
40
- outputs = model(input_ids, attention_mask=attention_mask)
41
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
42
 
43
- # Apply thresholds with user-defined confidence boost
44
- predictions = []
45
- for i, (logit, thresh) in enumerate(zip(logits, default_thresholds)):
46
- adjusted_thresh = max(thresh, confidence_threshold)
47
- if logit >= adjusted_thresh:
48
- predictions.append((emotion_labels[i], logit))
49
 
50
- predictions.sort(key=lambda x: x[1], reverse=True)
51
- if not predictions:
52
- return "No emotions predicted above thresholds.", None
 
 
 
53
 
54
- # Format output
55
- text_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions])
 
 
56
 
57
- # Create bar chart
58
- df = pd.DataFrame(predictions, columns=["Emotion", "Confidence"])
59
- fig = px.bar(
60
- df,
61
- x="Emotion",
62
- y="Confidence",
63
- color="Emotion",
64
- text="Confidence",
65
- title="Emotion Confidence Levels",
66
- height=400
67
- )
68
- fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
69
- fig.update_layout(showlegend=False, margin=dict(t=40, b=40))
 
 
70
 
71
- return text_output, fig
72
 
73
- # Custom CSS for modern UI
74
  custom_css = """
75
  body {
76
  font-family: 'Segoe UI', Arial, sans-serif;
@@ -123,7 +133,6 @@ body {
123
  }
124
  """
125
 
126
- # JavaScript for theme toggle
127
  theme_js = """
128
  function toggleTheme() {
129
  document.body.classList.toggle('dark-mode');
@@ -138,8 +147,8 @@ with gr.Blocks(css=custom_css) as demo:
138
  """
139
  <div id='description'>
140
  Predict emotions from text using a fine-tuned BERT-base model.
141
- Explore 28 emotions with optimized thresholds (Micro F1: 0.6025).
142
- Try examples or enter your own text!
143
  </div>
144
  """,
145
  elem_id="description"
@@ -173,8 +182,10 @@ with gr.Blocks(css=custom_css) as demo:
173
  submit_btn = gr.Button("Predict Emotions", variant="primary")
174
 
175
  with gr.Column(scale=1):
176
- output_text = gr.Textbox(label="Predicted Emotions", lines=5)
177
- output_plot = gr.Plot(label="Emotion Confidence Chart")
 
 
178
 
179
  # Example carousel
180
  examples = gr.Examples(
@@ -191,9 +202,9 @@ with gr.Blocks(css=custom_css) as demo:
191
 
192
  # Bind prediction
193
  submit_btn.click(
194
- fn=predict_emotions,
195
  inputs=[text_input, confidence_slider],
196
- outputs=[output_text, output_plot]
197
  )
198
 
199
  # Launch
 
1
  import gradio as gr
 
 
 
 
 
 
2
  import pandas as pd
3
+ import plotly.express as px
4
+ import shutil
5
+ import os
6
+ from huggingface_hub import hf_hub_download
7
+ from importlib import import_module
8
 
 
9
  repo_id = "logasanjeev/goemotions-bert"
10
+ local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
11
+ print("Downloaded inference.py successfully!")
12
+
13
+ current_dir = os.getcwd()
14
+ destination = os.path.join(current_dir, "inference.py")
15
+ shutil.copy(local_file, destination)
16
+ print("Copied inference.py to current directory!")
17
 
18
+ inference_module = import_module("inference")
19
+ predict_emotions = inference_module.predict_emotions
20
+ print("Imported predict_emotions successfully!")
 
 
 
21
 
22
+ _, _ = predict_emotions("dummy text")
23
+ emotion_labels = inference_module.EMOTION_LABELS
24
+ default_thresholds = inference_module.THRESHOLDS
25
+
26
+ def predict_emotions_with_details(text, confidence_threshold=0.0):
27
+ predictions_str, processed_text = predict_emotions(text)
28
+
29
+ predictions = []
30
+ if predictions_str != "No emotions predicted.":
31
+ for line in predictions_str.split("\n"):
32
+ emotion, confidence = line.split(": ")
33
+ predictions.append((emotion, float(confidence)))
34
+
35
+ encodings = inference_module.TOKENIZER(
36
+ processed_text,
37
  padding='max_length',
38
  truncation=True,
39
  max_length=128,
40
  return_tensors='pt'
41
  )
42
+ input_ids = encodings['input_ids'].to(inference_module.DEVICE)
43
+ attention_mask = encodings['attention_mask'].to(inference_module.DEVICE)
44
 
45
  with torch.no_grad():
46
+ outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
47
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
48
 
49
+ all_emotions = [(emotion_labels[i], round(logit, 4)) for i, logit in enumerate(logits)]
50
+ all_emotions.sort(key=lambda x: x[1], reverse=True)
51
+ top_5_emotions = all_emotions[:5]
52
+ top_5_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions])
 
 
53
 
54
+ filtered_predictions = []
55
+ for emotion, confidence in predictions:
56
+ thresh = default_thresholds[emotion_labels.index(emotion)]
57
+ adjusted_thresh = max(thresh, confidence_threshold)
58
+ if confidence >= adjusted_thresh:
59
+ filtered_predictions.append((emotion, confidence))
60
 
61
+ if not filtered_predictions:
62
+ thresholded_output = "No emotions predicted above thresholds."
63
+ else:
64
+ thresholded_output = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions])
65
 
66
+ if filtered_predictions:
67
+ df = pd.DataFrame(filtered_predictions, columns=["Emotion", "Confidence"])
68
+ fig = px.bar(
69
+ df,
70
+ x="Emotion",
71
+ y="Confidence",
72
+ color="Emotion",
73
+ text="Confidence",
74
+ title="Emotion Confidence Levels (Above Threshold)",
75
+ height=400
76
+ )
77
+ fig.update_traces(texttemplate='%{text:.2f}', textposition='auto')
78
+ fig.update_layout(showlegend=False, margin=dict(t=40, b=40))
79
+ else:
80
+ fig = None
81
 
82
+ return processed_text, thresholded_output, top_5_output, fig
83
 
 
84
  custom_css = """
85
  body {
86
  font-family: 'Segoe UI', Arial, sans-serif;
 
133
  }
134
  """
135
 
 
136
  theme_js = """
137
  function toggleTheme() {
138
  document.body.classList.toggle('dark-mode');
 
147
  """
148
  <div id='description'>
149
  Predict emotions from text using a fine-tuned BERT-base model.
150
+ Explore 28 emotions with optimized thresholds (Micro F1: 0.6006).
151
+ View preprocessed text, top 5 emotions, and thresholded predictions!
152
  </div>
153
  """,
154
  elem_id="description"
 
182
  submit_btn = gr.Button("Predict Emotions", variant="primary")
183
 
184
  with gr.Column(scale=1):
185
+ processed_text_output = gr.Textbox(label="Preprocessed Text", lines=2)
186
+ thresholded_output = gr.Textbox(label="Predicted Emotions (Above Threshold)", lines=5)
187
+ top_5_output = gr.Textbox(label="Top 5 Emotions (Regardless of Threshold)", lines=5)
188
+ output_plot = gr.Plot(label="Emotion Confidence Chart (Above Threshold)")
189
 
190
  # Example carousel
191
  examples = gr.Examples(
 
202
 
203
  # Bind prediction
204
  submit_btn.click(
205
+ fn=predict_emotions_with_details,
206
  inputs=[text_input, confidence_slider],
207
+ outputs=[processed_text_output, thresholded_output, top_5_output, output_plot]
208
  )
209
 
210
  # Launch