Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -26,23 +26,23 @@ def classify_review(text):
|
|
26 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
27 |
with torch.no_grad():
|
28 |
outputs = model(**inputs)
|
29 |
-
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
30 |
-
|
31 |
-
predicted_label = labels[
|
32 |
|
33 |
# Logging
|
34 |
with open(log_path, mode='a', newline='') as file:
|
35 |
writer = csv.writer(file)
|
36 |
writer.writerow([datetime.now().isoformat(), text, predicted_label])
|
37 |
|
38 |
-
return
|
39 |
|
40 |
def save_correction(text, predicted_label, user_correction):
|
41 |
if user_correction != predicted_label:
|
42 |
with open(corrections_path, mode='a', newline='') as file:
|
43 |
writer = csv.writer(file)
|
44 |
writer.writerow([datetime.now().isoformat(), text, predicted_label, user_correction])
|
45 |
-
return f"Thanks! Correction recorded: {user_correction}"
|
46 |
|
47 |
# Gradio UI
|
48 |
with gr.Blocks() as demo:
|
@@ -50,7 +50,7 @@ with gr.Blocks() as demo:
|
|
50 |
gr.Markdown("Enter a course review and get the sentiment prediction. You can correct the result if needed.")
|
51 |
|
52 |
input_text = gr.Textbox(lines=4, placeholder="Enter course review here...")
|
53 |
-
output_label = gr.Label(num_top_classes=3)
|
54 |
predict_btn = gr.Button("Classify")
|
55 |
|
56 |
with gr.Row(visible=False) as correction_row:
|
@@ -62,13 +62,12 @@ with gr.Blocks() as demo:
|
|
62 |
hidden_text = gr.Textbox(visible=False)
|
63 |
hidden_pred = gr.Textbox(visible=False)
|
64 |
|
65 |
-
def show_correction_ui(
|
66 |
-
return
|
67 |
-
|
68 |
-
predict_btn.click(classify_review, inputs=input_text, outputs=[output_label, hidden_text, hidden_pred])\
|
69 |
-
.then(show_correction_ui, inputs=[output_label, hidden_text, hidden_pred],
|
70 |
-
outputs=[output_label, correction_row, hidden_text, hidden_pred])
|
71 |
|
|
|
|
|
|
|
72 |
|
73 |
submit_btn.click(save_correction, inputs=[hidden_text, hidden_pred, correction_dropdown], outputs=correction_status)
|
74 |
|
|
|
26 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
27 |
with torch.no_grad():
|
28 |
outputs = model(**inputs)
|
29 |
+
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
|
30 |
+
prob_dict = {label: float(prob) for label, prob in zip(labels, probs)}
|
31 |
+
predicted_label = labels[torch.argmax(probs).item()]
|
32 |
|
33 |
# Logging
|
34 |
with open(log_path, mode='a', newline='') as file:
|
35 |
writer = csv.writer(file)
|
36 |
writer.writerow([datetime.now().isoformat(), text, predicted_label])
|
37 |
|
38 |
+
return prob_dict, text, predicted_label
|
39 |
|
40 |
def save_correction(text, predicted_label, user_correction):
|
41 |
if user_correction != predicted_label:
|
42 |
with open(corrections_path, mode='a', newline='') as file:
|
43 |
writer = csv.writer(file)
|
44 |
writer.writerow([datetime.now().isoformat(), text, predicted_label, user_correction])
|
45 |
+
return f"✅ Thanks! Correction recorded: {user_correction}"
|
46 |
|
47 |
# Gradio UI
|
48 |
with gr.Blocks() as demo:
|
|
|
50 |
gr.Markdown("Enter a course review and get the sentiment prediction. You can correct the result if needed.")
|
51 |
|
52 |
input_text = gr.Textbox(lines=4, placeholder="Enter course review here...")
|
53 |
+
output_label = gr.Label(num_top_classes=3, label="Predicted Sentiment")
|
54 |
predict_btn = gr.Button("Classify")
|
55 |
|
56 |
with gr.Row(visible=False) as correction_row:
|
|
|
62 |
hidden_text = gr.Textbox(visible=False)
|
63 |
hidden_pred = gr.Textbox(visible=False)
|
64 |
|
65 |
+
def show_correction_ui(_, text, pred):
|
66 |
+
return gr.update(visible=True), text, pred
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
predict_btn.click(classify_review, inputs=input_text, outputs=[output_label, hidden_text, hidden_pred]) \
|
69 |
+
.then(show_correction_ui, inputs=[output_label, hidden_text, hidden_pred],
|
70 |
+
outputs=[correction_row, hidden_text, hidden_pred])
|
71 |
|
72 |
submit_btn.click(save_correction, inputs=[hidden_text, hidden_pred, correction_dropdown], outputs=correction_status)
|
73 |
|