debojit01 commited on
Commit
b4d96e5
·
verified ·
1 Parent(s): 7573a0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -26,23 +26,23 @@ def classify_review(text):
26
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
30
- label_idx = torch.argmax(probs).item()
31
- predicted_label = labels[label_idx]
32
 
33
  # Logging
34
  with open(log_path, mode='a', newline='') as file:
35
  writer = csv.writer(file)
36
  writer.writerow([datetime.now().isoformat(), text, predicted_label])
37
 
38
- return {label: float(prob) for label, prob in zip(labels, probs[0])}, text, predicted_label
39
 
40
  def save_correction(text, predicted_label, user_correction):
41
  if user_correction != predicted_label:
42
  with open(corrections_path, mode='a', newline='') as file:
43
  writer = csv.writer(file)
44
  writer.writerow([datetime.now().isoformat(), text, predicted_label, user_correction])
45
- return f"Thanks! Correction recorded: {user_correction}"
46
 
47
  # Gradio UI
48
  with gr.Blocks() as demo:
@@ -50,7 +50,7 @@ with gr.Blocks() as demo:
50
  gr.Markdown("Enter a course review and get the sentiment prediction. You can correct the result if needed.")
51
 
52
  input_text = gr.Textbox(lines=4, placeholder="Enter course review here...")
53
- output_label = gr.Label(num_top_classes=3)
54
  predict_btn = gr.Button("Classify")
55
 
56
  with gr.Row(visible=False) as correction_row:
@@ -62,13 +62,12 @@ with gr.Blocks() as demo:
62
  hidden_text = gr.Textbox(visible=False)
63
  hidden_pred = gr.Textbox(visible=False)
64
 
65
- def show_correction_ui(result, text, pred):
66
- return result, gr.update(visible=True), text, pred
67
-
68
- predict_btn.click(classify_review, inputs=input_text, outputs=[output_label, hidden_text, hidden_pred])\
69
- .then(show_correction_ui, inputs=[output_label, hidden_text, hidden_pred],
70
- outputs=[output_label, correction_row, hidden_text, hidden_pred])
71
 
 
 
 
72
 
73
  submit_btn.click(save_correction, inputs=[hidden_text, hidden_pred, correction_dropdown], outputs=correction_status)
74
 
 
26
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
27
  with torch.no_grad():
28
  outputs = model(**inputs)
29
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
30
+ prob_dict = {label: float(prob) for label, prob in zip(labels, probs)}
31
+ predicted_label = labels[torch.argmax(probs).item()]
32
 
33
  # Logging
34
  with open(log_path, mode='a', newline='') as file:
35
  writer = csv.writer(file)
36
  writer.writerow([datetime.now().isoformat(), text, predicted_label])
37
 
38
+ return prob_dict, text, predicted_label
39
 
40
  def save_correction(text, predicted_label, user_correction):
41
  if user_correction != predicted_label:
42
  with open(corrections_path, mode='a', newline='') as file:
43
  writer = csv.writer(file)
44
  writer.writerow([datetime.now().isoformat(), text, predicted_label, user_correction])
45
+ return f"Thanks! Correction recorded: {user_correction}"
46
 
47
  # Gradio UI
48
  with gr.Blocks() as demo:
 
50
  gr.Markdown("Enter a course review and get the sentiment prediction. You can correct the result if needed.")
51
 
52
  input_text = gr.Textbox(lines=4, placeholder="Enter course review here...")
53
+ output_label = gr.Label(num_top_classes=3, label="Predicted Sentiment")
54
  predict_btn = gr.Button("Classify")
55
 
56
  with gr.Row(visible=False) as correction_row:
 
62
  hidden_text = gr.Textbox(visible=False)
63
  hidden_pred = gr.Textbox(visible=False)
64
 
65
+ def show_correction_ui(_, text, pred):
66
+ return gr.update(visible=True), text, pred
 
 
 
 
67
 
68
+ predict_btn.click(classify_review, inputs=input_text, outputs=[output_label, hidden_text, hidden_pred]) \
69
+ .then(show_correction_ui, inputs=[output_label, hidden_text, hidden_pred],
70
+ outputs=[correction_row, hidden_text, hidden_pred])
71
 
72
  submit_btn.click(save_correction, inputs=[hidden_text, hidden_pred, correction_dropdown], outputs=correction_status)
73