debojit01 commited on
Commit
6e8cb0a
·
verified ·
1 Parent(s): 40e4906

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
4
+ from datetime import datetime
5
+ import csv
6
+ import os
7
+
8
+ # Load model and tokenizer
9
+ model = DistilBertForSequenceClassification.from_pretrained("debojit01/course-review-sentiment")
10
+ tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
11
+
12
+ labels = ['negative', 'neutral', 'positive']
13
+
14
+ # Setup log files
15
+ log_path = "logs.csv"
16
+ corrections_path = "corrections.csv"
17
+
18
+ for path, headers in [(log_path, ["timestamp", "input_text", "predicted_label"]),
19
+ (corrections_path, ["timestamp", "input_text", "predicted_label", "user_correction"])]:
20
+ if not os.path.exists(path):
21
+ with open(path, mode='w', newline='') as file:
22
+ writer = csv.writer(file)
23
+ writer.writerow(headers)
24
+
25
+ def classify_review(text):
26
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
27
+ with torch.no_grad():
28
+ outputs = model(**inputs)
29
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
30
+ label_idx = torch.argmax(probs).item()
31
+ predicted_label = labels[label_idx]
32
+
33
+ # Logging
34
+ with open(log_path, mode='a', newline='') as file:
35
+ writer = csv.writer(file)
36
+ writer.writerow([datetime.now().isoformat(), text, predicted_label])
37
+
38
+ return {label: float(prob) for label, prob in zip(labels, probs[0])}, text, predicted_label
39
+
40
+ def save_correction(text, predicted_label, user_correction):
41
+ if user_correction != predicted_label:
42
+ with open(corrections_path, mode='a', newline='') as file:
43
+ writer = csv.writer(file)
44
+ writer.writerow([datetime.now().isoformat(), text, predicted_label, user_correction])
45
+ return f"Thanks! Correction recorded: {user_correction}"
46
+
47
+ # Gradio UI
48
+ with gr.Blocks() as demo:
49
+ gr.Markdown("# 📘 Course Review Sentiment Classifier")
50
+ gr.Markdown("Enter a course review and get the sentiment prediction. You can correct the result if needed.")
51
+
52
+ input_text = gr.Textbox(lines=4, placeholder="Enter course review here...")
53
+ output_label = gr.Label(num_top_classes=3)
54
+ predict_btn = gr.Button("Classify")
55
+
56
+ with gr.Row(visible=False) as correction_row:
57
+ gr.Markdown("### ❓ Is the prediction wrong?")
58
+ correction_dropdown = gr.Dropdown(choices=labels, label="Correct Sentiment")
59
+ submit_btn = gr.Button("Submit Correction")
60
+ correction_status = gr.Textbox(interactive=False)
61
+
62
+ hidden_text = gr.Textbox(visible=False)
63
+ hidden_pred = gr.Textbox(visible=False)
64
+
65
+ def show_correction_ui(result, text, pred):
66
+ return result, gr.update(visible=True), text, pred
67
+
68
+ predict_btn.click(classify_review, inputs=input_text, outputs=[output_label, hidden_text, hidden_pred])\
69
+ .then(show_correction_ui, outputs=[output_label, correction_row, hidden_text, hidden_pred])
70
+
71
+ submit_btn.click(save_correction, inputs=[hidden_text, hidden_pred, correction_dropdown], outputs=correction_status)
72
+
73
+ if __name__ == "__main__":
74
+ demo.launch()