File size: 3,162 Bytes
6e8cb0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4d96e5
 
 
6e8cb0a
 
 
 
 
 
b4d96e5
6e8cb0a
 
 
 
 
 
b4d96e5
6e8cb0a
 
 
 
 
 
 
b4d96e5
6e8cb0a
 
 
 
 
 
 
 
 
 
 
b4d96e5
 
7573a0d
b4d96e5
 
 
6e8cb0a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from datetime import datetime
import csv
import os

# Load model and tokenizer
# The classifier weights and the tokenizer are fetched from two different
# checkpoints, as spelled out in the identifiers below.
model = DistilBertForSequenceClassification.from_pretrained(
    "debojit01/course-review-sentiment"
)
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Position matters: classify_review pairs labels[i] with the i-th class score.
labels = ['negative', 'neutral', 'positive']

# Setup log files
log_path = "logs.csv"
corrections_path = "corrections.csv"

# Header row for each CSV file, keyed by the file's path.
_csv_headers = {
    log_path: ["timestamp", "input_text", "predicted_label"],
    corrections_path: ["timestamp", "input_text", "predicted_label", "user_correction"],
}

for path, headers in _csv_headers.items():
    if os.path.exists(path):
        # Leave an existing file alone so earlier rows are not overwritten.
        continue
    with open(path, mode='w', newline='') as handle:
        csv.writer(handle).writerow(headers)

def classify_review(text):
    """Classify a course review and append the prediction to the log file.

    Args:
        text: The review text to classify.

    Returns:
        A 3-tuple ``(prob_dict, text, predicted_label)``:
        ``prob_dict`` maps each entry of ``labels`` to its softmax
        probability as a float, ``text`` is the input passed back unchanged,
        and ``predicted_label`` is the label with the highest probability.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over the last (class) dimension; [0] selects the first row.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    prob_dict = {label: float(prob) for label, prob in zip(labels, probs)}
    predicted_label = labels[torch.argmax(probs).item()]

    # Logging. The review is free-form user text, so the encoding is pinned:
    # the platform default may be unable to encode emoji or accented
    # characters and would raise UnicodeEncodeError on write.
    with open(log_path, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow([datetime.now().isoformat(), text, predicted_label])

    return prob_dict, text, predicted_label

def save_correction(text, predicted_label, user_correction):
    """Record a user-supplied label when it disagrees with the prediction.

    A row is appended to the corrections CSV only when a label was actually
    chosen and it differs from ``predicted_label``.

    Args:
        text: The review text that was classified.
        predicted_label: The label the model assigned to ``text``.
        user_correction: The label chosen by the user; a falsy value
            (``None`` or ``""``) means nothing was selected.

    Returns:
        A status message for display: a prompt to pick a label, a note that
        the prediction already matches, or a confirmation that the
        correction was recorded.
    """
    # Nothing selected: do not write a row with an empty correction.
    if not user_correction:
        return "Please select the correct sentiment before submitting."

    # The chosen label agrees with the model, so there is nothing to store;
    # say so instead of claiming a correction was recorded.
    if user_correction == predicted_label:
        return f"No correction needed: the prediction is already '{user_correction}'."

    # Explicit UTF-8 so reviews with emoji or accented characters can always
    # be written, regardless of the platform's default encoding.
    with open(corrections_path, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow([datetime.now().isoformat(), text, predicted_label, user_correction])

    # NOTE(review): the leading characters look like a mis-encoded emoji;
    # kept as found, confirm against the file's intended encoding.
    return f"βœ… Thanks! Correction recorded: {user_correction}"

# Gradio UI
# Components are declared in the order they appear below; the correction
# controls start hidden and are revealed after a classification.
with gr.Blocks() as demo:
    # NOTE(review): the characters after "# " look like a mis-encoded emoji;
    # left as found, confirm against the file's intended encoding.
    gr.Markdown("# πŸ“˜ Course Review Sentiment Classifier")
    gr.Markdown("Enter a course review and get the sentiment prediction. You can correct the result if needed.")

    input_text = gr.Textbox(lines=4, placeholder="Enter course review here...")
    output_label = gr.Label(num_top_classes=3, label="Predicted Sentiment")
    predict_btn = gr.Button("Classify")

    # Correction controls, created with visible=False.
    with gr.Row(visible=False) as correction_row:
        # NOTE(review): same apparent mis-encoding as the title above.
        gr.Markdown("### ❓ Is the prediction wrong?")
        correction_dropdown = gr.Dropdown(choices=labels, label="Correct Sentiment")
        submit_btn = gr.Button("Submit Correction")
        correction_status = gr.Textbox(interactive=False)

    # Invisible textboxes that carry the classified text and its predicted
    # label from classify_review to save_correction.
    hidden_text = gr.Textbox(visible=False)
    hidden_pred = gr.Textbox(visible=False)

    def show_correction_ui(_, text, pred):
        """Make the correction row visible and pass text/pred through unchanged.

        The first argument (the label output) is accepted but ignored.
        """
        return gr.update(visible=True), text, pred

    # Classify on click, filling the label and the two hidden fields; the
    # chained .then() step calls show_correction_ui to reveal the correction row.
    predict_btn.click(classify_review, inputs=input_text, outputs=[output_label, hidden_text, hidden_pred]) \
               .then(show_correction_ui, inputs=[output_label, hidden_text, hidden_pred],
                     outputs=[correction_row, hidden_text, hidden_pred])

    # Submit the stored text and prediction together with the chosen label;
    # the returned message is shown in correction_status.
    submit_btn.click(save_correction, inputs=[hidden_text, hidden_pred, correction_dropdown], outputs=correction_status)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()