Upload app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# app.py
|
2 |
import unsloth
|
3 |
from unsloth import FastModel
|
4 |
|
@@ -55,7 +55,7 @@ class GPTSequenceClassifier(nn.Module):
|
|
55 |
|
56 |
|
57 |
# ===================================================================
|
58 |
-
#
|
59 |
# ===================================================================
|
60 |
|
61 |
# --- Helper Functions ---
|
@@ -109,7 +109,7 @@ def evaluate_equations(eq_dict: dict, sol_dict: dict):
|
|
109 |
correct_rhs_val = round(lhs_val, 4)
|
110 |
correct_rhs_str = f"{correct_rhs_val:.4f}".rstrip('0').rstrip('.')
|
111 |
|
112 |
-
|
113 |
return {
|
114 |
"error": True,
|
115 |
"line_key": key,
|
@@ -235,7 +235,7 @@ logger.info("load_model(): %s", msg)
|
|
235 |
|
236 |
|
237 |
# ===================================================================
|
238 |
-
#
|
239 |
# ===================================================================
|
240 |
|
241 |
def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict:
|
@@ -255,7 +255,7 @@ def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict
|
|
255 |
with torch.inference_mode():
|
256 |
outputs = model(**inputs, use_cache=False)
|
257 |
|
258 |
-
|
259 |
logits = outputs["logits"].to(torch.float32)
|
260 |
probs = torch.softmax(logits, dim=-1).squeeze().tolist()
|
261 |
|
@@ -327,11 +327,11 @@ def analyze_solution(question: str, solution: str):
|
|
327 |
"""
|
328 |
Main orchestrator that runs the full pipeline and generates the final explanation.
|
329 |
"""
|
330 |
-
# STAGE 1: Conceptual Check
|
331 |
conceptual_result = run_conceptual_check(question, solution, classifier_model, classifier_tokenizer)
|
332 |
confidence = conceptual_result['probabilities'][conceptual_result['prediction']]
|
333 |
|
334 |
-
# STAGE 2: Computational Check
|
335 |
computational_result = run_computational_check(solution, gemma_model, gemma_tokenizer)
|
336 |
|
337 |
# FINAL VERDICT LOGIC
|
@@ -372,13 +372,13 @@ def classify_solution_stream(question: str, solution: str):
|
|
372 |
|
373 |
log = []
|
374 |
|
375 |
-
|
376 |
if not question.strip() or not solution.strip():
|
377 |
log.append("⚠️ Provide a question and a solution.")
|
378 |
yield "Please fill in both fields", "", render(log)
|
379 |
return
|
380 |
|
381 |
-
|
382 |
if not models_ready():
|
383 |
log.append("⏳ Loading models…")
|
384 |
yield "⏳ Working…", "", render(log)
|
@@ -444,7 +444,7 @@ def classify_solution_stream(question: str, solution: str):
|
|
444 |
yield "Runtime error", f"{type(e).__name__}: {e}", render(log)
|
445 |
|
446 |
|
447 |
-
# ---------------- UI: streaming
|
448 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
449 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
450 |
gr.Markdown(
|
@@ -665,7 +665,7 @@ class ExampleSelector:
|
|
665 |
else:
|
666 |
self.balance["wrong"] += 1
|
667 |
|
668 |
-
# ===== CSV hookup
|
669 |
from pathlib import Path
|
670 |
import time
|
671 |
|
@@ -673,10 +673,10 @@ CSV_PATH = Path(__file__).resolve().parent / "final-test-with-wrong-answers.csv"
|
|
673 |
POOL = load_examples_csv(str(CSV_PATH))
|
674 |
|
675 |
def new_selector(seed: int | None = None):
    """Build a fresh ``ExampleSelector`` over the module-level ``POOL``.

    Args:
        seed: Seed forwarded to ``ExampleSelector``. When ``None``, a seed is
            derived from the current wall-clock time, masked to 16 bits.

    Returns:
        A new ``ExampleSelector`` constructed as ``ExampleSelector(POOL, seed=...)``.
    """
    if seed is None:
        # Compare against None rather than relying on truthiness, so that an
        # explicit seed of 0 is honoured instead of being replaced by the clock.
        seed = int(time.time()) & 0xFFFF
    return ExampleSelector(POOL, seed=seed)
|
678 |
|
679 |
-
|
680 |
def _truncate(s: str, n: int = 100) -> str:
|
681 |
s = s or ""
|
682 |
return s if len(s) <= n else s[: n - 1] + "…"
|
@@ -694,7 +694,6 @@ def _rows_to_table(rows: list[dict]) -> list[list[str]]:
|
|
694 |
return table
|
695 |
|
696 |
|
697 |
-
# ===== Gradio callbacks for examples =====
|
698 |
def ui_surprise(selector, filter_label="any"):
|
699 |
"""Pick one example and push it straight to inputs; persist selector state."""
|
700 |
if selector is None or not POOL:
|
@@ -704,9 +703,13 @@ def ui_surprise(selector, filter_label="any"):
|
|
704 |
return selector, gr.update(), gr.update()
|
705 |
return selector, r["question"], r["solution"]
|
706 |
|
|
|
|
|
|
|
|
|
|
|
707 |
|
708 |
|
709 |
-
# ---------------- UI: add CSV-driven examples ----------------
|
710 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
711 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
712 |
gr.Markdown(
|
@@ -715,7 +718,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
715 |
" \n Press 'Surprise me' to randomly select a sample question/answer pair from our dataset."
|
716 |
)
|
717 |
|
718 |
-
|
719 |
selector_state = gr.State(new_selector())
|
720 |
|
721 |
with gr.Row():
|
@@ -723,12 +726,12 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
723 |
with gr.Column(scale=1):
|
724 |
question_input = gr.Textbox(
|
725 |
label="Math Question",
|
726 |
-
placeholder="e.g.,
|
727 |
lines=3,
|
728 |
)
|
729 |
solution_input = gr.Textbox(
|
730 |
label="Proposed Solution",
|
731 |
-
placeholder="e.g.,
|
732 |
lines=8,
|
733 |
)
|
734 |
expected_label_example = gr.Textbox(
|
@@ -738,7 +741,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
738 |
with gr.Row():
|
739 |
classify_btn = gr.Button("Classify Solution", variant="primary")
|
740 |
surprise_btn = gr.Button("Surprise me") # <- new
|
741 |
-
clear_btn = gr.
|
742 |
|
743 |
# -------- Right: outputs --------
|
744 |
with gr.Column(scale=1):
|
@@ -746,7 +749,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
746 |
explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=6)
|
747 |
status_output = gr.Markdown(value="*(idle)*") # live stage updates
|
748 |
|
749 |
-
# -------- Curated starter examples
|
750 |
gr.Examples(
|
751 |
examples=[
|
752 |
["John has three apples and Mary has seven, how many apples do they have together?",
|
@@ -755,18 +758,18 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
755 |
["A rectangle's length is twice its width. If the width is 7 cm, what is the perimeter of the rectangle?",
|
756 |
"The length of the rectangle is 2 * 7 = 14 cm.\n The perimeter is 14 + 7 = 21 cm.\n Final answer: 21",
|
757 |
"Conceptually flawed"],
|
758 |
-
["
|
759 |
-
"The
|
760 |
-
"
|
761 |
["What is 15% of 200?",
|
762 |
"15% = 15/100 = 0.15\n0.15 × 200 = 30\n Final answer: 30",
|
763 |
"Correct"],
|
764 |
["A circle has a radius of 5 cm. Using the approximation pi = 3.14, what is the circumference of the circle?",
|
765 |
"The circumference of the circle is 3.14 * 5 = 15.7 cm.\n Final answer: 15.7",
|
766 |
"Conceptually flawed"],
|
767 |
-
["
|
768 |
-
"The
|
769 |
-
"
|
770 |
["A 24-meter rope is cut into 6 equal pieces. A climber uses 2 of those pieces. How many meters of rope are still unused?",
|
771 |
"The length of each piece is 24 / 6 = 4 m.\n The climber uses 2 × 4 m = 8 m of rope.\n There are 24 m − 8 m = 16 m of rope still unused.\n Final answer: 16",
|
772 |
"Correct"]
|
@@ -776,7 +779,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
776 |
|
777 |
|
778 |
# ---------- Wiring ----------
|
779 |
-
# Main classify
|
780 |
classify_btn.click(
|
781 |
fn=classify_solution_stream,
|
782 |
inputs=[question_input, solution_input],
|
@@ -785,16 +788,14 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
785 |
concurrency_limit=1,
|
786 |
)
|
787 |
|
788 |
-
# ---- and replace the Surprise button wiring with this ----
|
789 |
surprise_btn.click(
|
790 |
fn=ui_surprise,
|
791 |
-
inputs=[selector_state],
|
792 |
-
outputs=[selector_state, question_input, solution_input],
|
793 |
queue=True,
|
794 |
)
|
795 |
|
796 |
|
797 |
-
# Enable queue for streaming
|
798 |
app.queue()
|
799 |
|
800 |
if __name__ == "__main__":
|
|
|
1 |
+
# app.py
|
2 |
import unsloth
|
3 |
from unsloth import FastModel
|
4 |
|
|
|
55 |
|
56 |
|
57 |
# ===================================================================
|
58 |
+
# HELPERS
|
59 |
# ===================================================================
|
60 |
|
61 |
# --- Helper Functions ---
|
|
|
109 |
correct_rhs_val = round(lhs_val, 4)
|
110 |
correct_rhs_str = f"{correct_rhs_val:.4f}".rstrip('0').rstrip('.')
|
111 |
|
112 |
+
|
113 |
return {
|
114 |
"error": True,
|
115 |
"line_key": key,
|
|
|
235 |
|
236 |
|
237 |
# ===================================================================
|
238 |
+
# PIPELINE COMPONENTS
|
239 |
# ===================================================================
|
240 |
|
241 |
def run_conceptual_check(question: str, solution: str, model, tokenizer) -> dict:
|
|
|
255 |
with torch.inference_mode():
|
256 |
outputs = model(**inputs, use_cache=False)
|
257 |
|
258 |
+
|
259 |
logits = outputs["logits"].to(torch.float32)
|
260 |
probs = torch.softmax(logits, dim=-1).squeeze().tolist()
|
261 |
|
|
|
327 |
"""
|
328 |
Main orchestrator that runs the full pipeline and generates the final explanation.
|
329 |
"""
|
330 |
+
# STAGE 1: Conceptual Check
|
331 |
conceptual_result = run_conceptual_check(question, solution, classifier_model, classifier_tokenizer)
|
332 |
confidence = conceptual_result['probabilities'][conceptual_result['prediction']]
|
333 |
|
334 |
+
# STAGE 2: Computational Check
|
335 |
computational_result = run_computational_check(solution, gemma_model, gemma_tokenizer)
|
336 |
|
337 |
# FINAL VERDICT LOGIC
|
|
|
372 |
|
373 |
log = []
|
374 |
|
375 |
+
|
376 |
if not question.strip() or not solution.strip():
|
377 |
log.append("⚠️ Provide a question and a solution.")
|
378 |
yield "Please fill in both fields", "", render(log)
|
379 |
return
|
380 |
|
381 |
+
|
382 |
if not models_ready():
|
383 |
log.append("⏳ Loading models…")
|
384 |
yield "⏳ Working…", "", render(log)
|
|
|
444 |
yield "Runtime error", f"{type(e).__name__}: {e}", render(log)
|
445 |
|
446 |
|
447 |
+
# ---------------- UI: streaming ----------------
|
448 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
449 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
450 |
gr.Markdown(
|
|
|
665 |
else:
|
666 |
self.balance["wrong"] += 1
|
667 |
|
668 |
+
# ===== CSV hookup =====
|
669 |
from pathlib import Path
|
670 |
import time
|
671 |
|
|
|
673 |
POOL = load_examples_csv(str(CSV_PATH))
|
674 |
|
675 |
def new_selector(seed: int | None = None):
    """Build a fresh ``ExampleSelector`` over the module-level ``POOL``.

    Args:
        seed: Seed forwarded to ``ExampleSelector``. When ``None``, a seed is
            derived from the current wall-clock time, masked to 16 bits.

    Returns:
        A new ``ExampleSelector`` constructed as ``ExampleSelector(POOL, seed=...)``.
    """
    if seed is None:
        # Compare against None rather than relying on truthiness, so that an
        # explicit seed of 0 is honoured instead of being replaced by the clock.
        seed = int(time.time()) & 0xFFFF
    return ExampleSelector(POOL, seed=seed)
|
678 |
|
679 |
+
|
680 |
def _truncate(s: str, n: int = 100) -> str:
|
681 |
s = s or ""
|
682 |
return s if len(s) <= n else s[: n - 1] + "…"
|
|
|
694 |
return table
|
695 |
|
696 |
|
|
|
697 |
def ui_surprise(selector, filter_label="any"):
|
698 |
"""Pick one example and push it straight to inputs; persist selector state."""
|
699 |
if selector is None or not POOL:
|
|
|
703 |
return selector, gr.update(), gr.update()
|
704 |
return selector, r["question"], r["solution"]
|
705 |
|
706 |
+
components_to_clear = [
|
707 |
+
question_input,
|
708 |
+
solution_input,
|
709 |
+
expected_label_example,
|
710 |
+
]
|
711 |
|
712 |
|
|
|
713 |
with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
714 |
gr.Markdown("# 🧮 Math Solution Classifier")
|
715 |
gr.Markdown(
|
|
|
718 |
" \n Press 'Surprise me' to randomly select a sample question/answer pair from our dataset."
|
719 |
)
|
720 |
|
721 |
+
|
722 |
selector_state = gr.State(new_selector())
|
723 |
|
724 |
with gr.Row():
|
|
|
726 |
with gr.Column(scale=1):
|
727 |
question_input = gr.Textbox(
|
728 |
label="Math Question",
|
729 |
+
placeholder="e.g., What is 14 divided by 2?",
|
730 |
lines=3,
|
731 |
)
|
732 |
solution_input = gr.Textbox(
|
733 |
label="Proposed Solution",
|
734 |
+
placeholder="e.g., 14/2 = 9",
|
735 |
lines=8,
|
736 |
)
|
737 |
expected_label_example = gr.Textbox(
|
|
|
741 |
with gr.Row():
|
742 |
classify_btn = gr.Button("Classify Solution", variant="primary")
|
743 |
surprise_btn = gr.Button("Surprise me") # <- new
|
744 |
+
clear_btn = clear_btn = gr.ClearButton(components=components_to_clear, value="Clear")
|
745 |
|
746 |
# -------- Right: outputs --------
|
747 |
with gr.Column(scale=1):
|
|
|
749 |
explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=6)
|
750 |
status_output = gr.Markdown(value="*(idle)*") # live stage updates
|
751 |
|
752 |
+
# -------- Curated starter examples --------
|
753 |
gr.Examples(
|
754 |
examples=[
|
755 |
["John has three apples and Mary has seven, how many apples do they have together?",
|
|
|
758 |
["A rectangle's length is twice its width. If the width is 7 cm, what is the perimeter of the rectangle?",
|
759 |
"The length of the rectangle is 2 * 7 = 14 cm.\n The perimeter is 14 + 7 = 21 cm.\n Final answer: 21",
|
760 |
"Conceptually flawed"],
|
761 |
+
["",
|
762 |
+
"The lateral area of the bottom layer is 2 * 3.14 * 20 * 8 = 1004.8.\n The lateral area of the middle layer is 2 * 3.14 * 15 * 8 = 753.6.\n The lateral area of the top layer is 2 * 3.14 * 10 * 8 = 502.4.\n The exposed top surface is the area of the smallest circle: 3.14 * (10*10) = 314.\n The total frosted area is 1004.8 + 753.6 + 502.4 + 314 = 2888.8 sq cm.\n FINAL ANSWER: 2888.8",
|
763 |
+
"Computationally flawed"],
|
764 |
["What is 15% of 200?",
|
765 |
"15% = 15/100 = 0.15\n0.15 × 200 = 30\n Final answer: 30",
|
766 |
"Correct"],
|
767 |
["A circle has a radius of 5 cm. Using the approximation pi = 3.14, what is the circumference of the circle?",
|
768 |
"The circumference of the circle is 3.14 * 5 = 15.7 cm.\n Final answer: 15.7",
|
769 |
"Conceptually flawed"],
|
770 |
+
["A library is building new shelves. Each shelf is 1.2 meters long. A standard book is 3 cm thick, and a large book is 5 cm thick. A shelf must hold 20 standard books and 10 large books. After filling a shelf with these books, how much space, in centimeters, is left on the shelf?",
|
771 |
+
"The shelf length in centimeters is 1.2 * 100 = 120 cm.\n The space taken by standard books is 20 * 3 = 60 cm.\n The space taken by large books is 10 * 5 = 50 cm.\n The total space taken is 60 + 50 = 110 cm.\n The remaining space is 120 + 110 = 230 cm.\n FINAL ANSWER: 230",
|
772 |
+
"Conceptually flawed"],
|
773 |
["A 24-meter rope is cut into 6 equal pieces. A climber uses 2 of those pieces. How many meters of rope are still unused?",
|
774 |
"The length of each piece is 24 / 6 = 4 m.\n The climber uses 2 × 4 m = 8 m of rope.\n There are 24 m − 8 m = 16 m of rope still unused.\n Final answer: 16",
|
775 |
"Correct"]
|
|
|
779 |
|
780 |
|
781 |
# ---------- Wiring ----------
|
782 |
+
# Main classify
|
783 |
classify_btn.click(
|
784 |
fn=classify_solution_stream,
|
785 |
inputs=[question_input, solution_input],
|
|
|
788 |
concurrency_limit=1,
|
789 |
)
|
790 |
|
|
|
791 |
surprise_btn.click(
|
792 |
fn=ui_surprise,
|
793 |
+
inputs=[selector_state],
|
794 |
+
outputs=[selector_state, question_input, solution_input],
|
795 |
queue=True,
|
796 |
)
|
797 |
|
798 |
|
|
|
799 |
app.queue()
|
800 |
|
801 |
if __name__ == "__main__":
|