mcamargo00 commited on
Commit
44ab37e
·
verified ·
1 Parent(s): 4599113

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +372 -22
app.py CHANGED
@@ -389,7 +389,7 @@ def classify_solution_stream(question: str, solution: str):
389
  return
390
  log[-1] = "✅ Models loaded."
391
 
392
- verdicts_mapping = {"correct": "Correct.", "conceptual_error": "Conceptual error.", "computational_error": "Computational error."}
393
 
394
  try:
395
  # ---------- Stage 1: Conceptual ----------
@@ -476,47 +476,397 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
476
  status_output = gr.Markdown(value="*(idle)*") # live stage updates
477
 
478
  # -------- Examples --------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  gr.Examples(
480
  examples=[
481
- ["John has three apples and Mary has seven, how many apples do they have together?",
 
 
482
  "They have 7 + 3 = 11 apples."],
483
- ["A tank holds 60 liters of fuel. A generator uses fuel at a rate of 5 liters per hour. After running for 9 hours, how many liters are still in the tank?",
484
- "The generator uses 5 L/h × 9 h = 45 L of fuel in 9 hours.\n Then, there remain 60 L + 45 L = 105 L in the tank.\n Final answer: 105 L"],
485
  ["What is 15% of 200?",
486
  "15% = 15/100 = 0.15\n0.15 × 200 = 30"],
487
- ["A 24-meter rope is cut into 6 equal pieces. A climber uses 2 of those pieces. How many meters of rope are still unused?",
488
- "The length of each piece is 24 / 6 = 4 m.\n The climber uses 2 × 4 m = 8 m of rope.\n There are 24 m − 8 m = 16 m of rope still unused."]
489
  ],
490
  inputs=[question_input, solution_input],
491
  )
492
 
493
- # -------- Wiring --------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  classify_btn.click(
495
- fn=classify_solution_stream, # <- generator that yields (classification, explanation, status)
496
  inputs=[question_input, solution_input],
497
  outputs=[classification_output, explanation_output, status_output],
498
- show_progress=False, # <- no Gradio progress bars
499
- concurrency_limit=1, # <- per-event limit (good for GPU)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  )
501
 
502
- clear_btn.click(
503
- lambda: ("", "", "", "", "*(idle)*"),
 
504
  inputs=None,
505
- outputs=[
506
- question_input,
507
- solution_input,
508
- classification_output,
509
- explanation_output,
510
- status_output,
511
- ],
 
 
512
  queue=False,
513
  )
514
 
515
- # enable queue for streaming (no deprecated args)
516
  app.queue()
517
 
518
  if __name__ == "__main__":
519
  app.launch()
520
 
521
-
522
-
 
389
  return
390
  log[-1] = "✅ Models loaded."
391
 
392
+ verdicts_mapping = {"correct": "Correct.", "conceptual_error": "Conceptually flawed.", "computational_error": "Computationally flawed."}
393
 
394
  try:
395
  # ---------- Stage 1: Conceptual ----------
 
476
  status_output = gr.Markdown(value="*(idle)*") # live stage updates
477
 
478
  # -------- Examples --------
479
+
480
+
481
+ import csv, random
482
+ from typing import Dict, Optional, List, Tuple
483
+
484
+ # ---------- Data structures ----------
485
+ class QAItem:
486
+ __slots__ = ("id", "question", "correct", "wrong", "error_type")
487
+ def __init__(self, id: int, question: str,
488
+ correct: Optional[str], wrong: Optional[str], error_type: Optional[str]):
489
+ self.id = id
490
+ self.question = question
491
+ self.correct = correct or None
492
+ self.wrong = wrong or None
493
+ self.error_type = (error_type or "").strip() or None # e.g., "computational_error" / "conceptual_error"
494
+
495
+ # ---------- CSV loader ----------
496
+ def load_examples_csv(path: str) -> Dict[int, QAItem]:
497
+ """
498
+ Loads CSV and returns a dict: {question_id: QAItem}
499
+ Accepts either 1 row per question (both solutions present) or 2 rows merged by `index`.
500
+ """
501
+ def norm(s: Optional[str]) -> str:
502
+ return (s or "").strip()
503
+
504
+ pool: Dict[int, QAItem] = {}
505
+ with open(path, "r", encoding="utf-8") as f:
506
+ rdr = csv.DictReader(f)
507
+ # normalize headers
508
+ fieldmap = {k: k.strip().lower() for k in rdr.fieldnames or []}
509
+ rows = []
510
+ for row in rdr:
511
+ r = {fieldmap.get(k, k).lower(): v for k, v in row.items()}
512
+ rows.append(r)
513
+
514
+ for r in rows:
515
+ try:
516
+ qid = int(norm(r.get("index")))
517
+ except Exception:
518
+ # skip bad index rows
519
+ continue
520
+
521
+ q = norm(r.get("question"))
522
+ ca = norm(r.get("correct_answer"))
523
+ wa = norm(r.get("wrong_answer"))
524
+ et = norm(r.get("error_type"))
525
+
526
+ if qid not in pool:
527
+ pool[qid] = QAItem(qid, q, ca, wa, et)
528
+ else:
529
+ # merge if the CSV has multiple rows per id
530
+ item = pool[qid]
531
+ if not item.question and q:
532
+ item.question = q
533
+ if ca and not item.correct:
534
+ item.correct = ca
535
+ if wa and not item.wrong:
536
+ item.wrong = wa
537
+ if et and not item.error_type:
538
+ item.error_type = et
539
+
540
+ # drop questions that have neither solution
541
+ pool = {k: v for k, v in pool.items() if (v.correct or v.wrong)}
542
+ return pool
543
+
544
+ # ---------- Selection state with balance ----------
545
+ class ExampleSelector:
546
+ """
547
+ Keeps one solution per question, balances correct vs wrong across picks,
548
+ and supports label filtering.
549
+ """
550
+ def __init__(self, pool: Dict[int, QAItem], seed: Optional[int] = None):
551
+ self.pool = pool
552
+ self._rng = random.Random(seed)
553
+ self.reset()
554
+
555
+ def reset(self):
556
+ self.ids: List[int] = list(self.pool.keys())
557
+ self._rng.shuffle(self.ids)
558
+ self.cursor: int = 0
559
+ self.seen_ids: set[int] = set()
560
+ self.balance = {"correct": 0, "wrong": 0}
561
+
562
+ # ---- public API ----
563
+ def next_batch(self, k: int, filter_label: str = "any") -> List[Dict]:
564
+ """Return up to k rows (id, question, solution, label), updating internal state."""
565
+ out: List[Dict] = []
566
+ # iterate over id list cyclically until filled or exhausted
567
+ tried = 0
568
+ max_tries = len(self.ids) * 2 # guard
569
+ while len(out) < k and tried < max_tries:
570
+ if self.cursor >= len(self.ids):
571
+ break
572
+ qid = self.ids[self.cursor]
573
+ self.cursor += 1
574
+ tried += 1
575
+
576
+ if qid in self.seen_ids:
577
+ continue
578
+
579
+ item = self.pool[qid]
580
+ variant = self._choose_variant(item, filter_label)
581
+ if variant is None:
582
+ continue # no variant matches filter
583
+
584
+ row = self._build_row(item, variant)
585
+ out.append(row)
586
+ self._mark_used(item, variant)
587
+ return out
588
+
589
+ def surprise(self, filter_label: str = "any") -> Optional[Dict]:
590
+ """Pick a single row at random (respecting filter & balance)."""
591
+ candidates = [qid for qid in self.ids if qid not in self.seen_ids and self._variant_available(self.pool[qid], filter_label)]
592
+ if not candidates:
593
+ return None
594
+ qid = self._rng.choice(candidates)
595
+ item = self.pool[qid]
596
+ variant = self._choose_variant(item, filter_label)
597
+ if variant is None:
598
+ return None
599
+ row = self._build_row(item, variant)
600
+ self._mark_used(item, variant)
601
+ return row
602
+
603
+ # ---- helpers ----
604
+ def _variant_available(self, item: QAItem, filter_label: str) -> bool:
605
+ return self._choose_variant(item, filter_label, dry_run=True) is not None
606
+
607
+ def _choose_variant(self, item: QAItem, filter_label: str, dry_run: bool = False) -> Optional[str]:
608
+ """
609
+ Returns 'correct' or 'wrong' or None given availability, filter, and current balance.
610
+ filter_label ∈ {"any","correct","wrong","computational_error","conceptual_error"}
611
+ """
612
+ has_correct = bool(item.correct)
613
+ has_wrong = bool(item.wrong)
614
+
615
+ want_correct = (filter_label == "correct")
616
+ want_wrong = (filter_label == "wrong") or (filter_label in ("computational_error", "conceptual_error"))
617
+
618
+ # Build allowed set based on filter
619
+ allowed = []
620
+ if filter_label == "any":
621
+ if has_correct: allowed.append("correct")
622
+ if has_wrong: allowed.append("wrong")
623
+ elif want_correct:
624
+ if has_correct: allowed.append("correct")
625
+ elif want_wrong:
626
+ if has_wrong and (filter_label in ("wrong", "any") or (item.error_type == filter_label)):
627
+ allowed.append("wrong")
628
+
629
+ if not allowed:
630
+ return None
631
+
632
+ if len(allowed) == 1:
633
+ return allowed[0]
634
+
635
+ # Balance correct vs wrong across already-shown items
636
+ c, w = self.balance["correct"], self.balance["wrong"]
637
+ if c > w and "wrong" in allowed:
638
+ return "wrong"
639
+ if w > c and "correct" in allowed:
640
+ return "correct"
641
+ # tie-breaker: prefer wrong when specifically filtering to an error type
642
+ if filter_label in ("computational_error", "conceptual_error") and "wrong" in allowed:
643
+ return "wrong"
644
+ return self._rng.choice(allowed)
645
+
646
+ def _build_row(self, item: QAItem, variant: str) -> Dict:
647
+ if variant == "correct":
648
+ label = "correct"
649
+ sol = item.correct
650
+ else:
651
+ label = item.error_type or "wrong"
652
+ sol = item.wrong
653
+ return {
654
+ "id": item.id,
655
+ "question": item.question,
656
+ "solution": sol,
657
+ "label": label, # "correct" | "computational_error" | "conceptual_error" | "wrong"
658
+ }
659
+
660
+ def _mark_used(self, item: QAItem, variant: str):
661
+ # we mark the whole question as used so we never show both solutions
662
+ self.seen_ids.add(item.id)
663
+ if variant == "correct":
664
+ self.balance["correct"] += 1
665
+ else:
666
+ self.balance["wrong"] += 1
667
+
668
+ # ===== CSV hookup (place near other imports / globals) =====
669
+ from pathlib import Path
670
+ import time
671
+
672
+ CSV_PATH = Path(__file__).resolve().parent / "examples.csv"
673
+ POOL = load_examples_csv(str(CSV_PATH))
674
+
675
+ def new_selector(seed: int | None = None):
676
+ # per-session selector; seed for reproducibility if you want
677
+ return ExampleSelector(POOL, seed=seed or int(time.time()) & 0xFFFF)
678
+
679
+ # small helpers for UI
680
+ def _truncate(s: str, n: int = 100) -> str:
681
+ s = s or ""
682
+ return s if len(s) <= n else s[: n - 1] + "…"
683
+
684
+ def _rows_to_table(rows: list[dict]) -> list[list[str]]:
685
+ # Dataframe value: list of rows [ID, Label, Question, Solution]
686
+ table = []
687
+ for r in rows:
688
+ table.append([
689
+ str(r["id"]),
690
+ r["label"],
691
+ _truncate(r["question"], 120),
692
+ _truncate(r["solution"], 120),
693
+ ])
694
+ return table
695
+
696
+ def _dropdown_choices(rows: list[dict]) -> list[tuple[str, int]]:
697
+ # Friendly labels mapped to ID values
698
+ choices = []
699
+ for r in rows:
700
+ label = f'#{r["id"]} — {r["label"]} — {_truncate(r["question"], 60)}'
701
+ choices.append((label, r["id"]))
702
+ return choices
703
+
704
+ # ===== Gradio callbacks for examples =====
705
+ def ui_see_more(selector, rows, filter_label):
706
+ """Append a chunk of examples to the browser."""
707
+ chunk = selector.next_batch(k=6, filter_label=filter_label)
708
+ rows = (rows or []) + chunk
709
+ return (
710
+ rows, # rows_state
711
+ gr.update(value=_rows_to_table(rows)), # examples_df
712
+ gr.update(choices=_dropdown_choices(rows), value=None), # row_picker
713
+ )
714
+
715
+ def ui_reset_examples():
716
+ """Reset per-session selector and clear the browser."""
717
+ sel = new_selector()
718
+ rows: list[dict] = []
719
+ return (
720
+ sel, # selector_state
721
+ rows, # rows_state
722
+ gr.update(value=_rows_to_table(rows)), # examples_df
723
+ gr.update(choices=[], value=None), # row_picker
724
+ )
725
+
726
+ def ui_load_selected(rows, selected_id):
727
+ """Load the selected example into the main inputs."""
728
+ if not rows or selected_id is None:
729
+ return gr.update(), gr.update()
730
+ for r in rows:
731
+ if r["id"] == selected_id:
732
+ return r["question"], r["solution"]
733
+ return gr.update(), gr.update()
734
+
735
+ def ui_surprise(selector, filter_label):
736
+ """Pick one example and push it straight to inputs."""
737
+ r = selector.surprise(filter_label=filter_label)
738
+ if not r:
739
+ # no more examples; keep inputs unchanged
740
+ return gr.update(), gr.update()
741
+ return r["question"], r["solution"]
742
+
743
+
744
+ # ---------------- UI: add CSV-driven examples ----------------
745
+ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
746
+ gr.Markdown("# 🧮 Math Solution Classifier")
747
+ gr.Markdown(
748
+ "Classify math solutions as **correct**, **conceptually flawed**, or **computationally flawed**. "
749
+ "Live status updates appear below as the two-stage pipeline runs."
750
+ )
751
+
752
+ # Per-session state
753
+ selector_state = gr.State(new_selector())
754
+ rows_state = gr.State([]) # list[dict] rows currently in the browser
755
+
756
+ with gr.Row():
757
+ # -------- Left: inputs --------
758
+ with gr.Column(scale=1):
759
+ question_input = gr.Textbox(
760
+ label="Math Question",
761
+ placeholder="e.g., Solve for x: 2x + 5 = 13",
762
+ lines=3,
763
+ )
764
+ solution_input = gr.Textbox(
765
+ label="Proposed Solution",
766
+ placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
767
+ lines=8,
768
+ )
769
+ with gr.Row():
770
+ classify_btn = gr.Button("Classify Solution", variant="primary")
771
+ surprise_btn = gr.Button("Surprise me") # <- new
772
+ clear_btn = gr.Button("Clear")
773
+
774
+ # -------- Right: outputs --------
775
+ with gr.Column(scale=1):
776
+ classification_output = gr.Textbox(label="Classification", interactive=False)
777
+ explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=6)
778
+ status_output = gr.Markdown(value="*(idle)*") # live stage updates
779
+
780
+ # -------- Curated starter examples (static) --------
781
  gr.Examples(
782
  examples=[
783
+ ["Solve for x: 2x + 5 = 13",
784
+ "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4"],
785
+ ["John has three apples and Mary has seven, how many apples do they have together?",
786
  "They have 7 + 3 = 11 apples."],
 
 
787
  ["What is 15% of 200?",
788
  "15% = 15/100 = 0.15\n0.15 × 200 = 30"],
 
 
789
  ],
790
  inputs=[question_input, solution_input],
791
  )
792
 
793
+ # -------- Dynamic browser (CSV) --------
794
+ with gr.Accordion("Browse more examples", open=False):
795
+ with gr.Row():
796
+ filter_dd = gr.Dropdown(
797
+ label="Filter",
798
+ choices=[
799
+ ("Any", "any"),
800
+ ("Correct only", "correct"),
801
+ ("Conceptual error only", "conceptual_error"),
802
+ ("Computational error only", "computational_error"),
803
+ ],
804
+ value="any",
805
+ allow_custom_value=False,
806
+ )
807
+ see_more_btn = gr.Button("See more")
808
+ reset_list_btn = gr.Button("Reset list")
809
+
810
+ examples_df = gr.Dataframe(
811
+ headers=["ID", "Label", "Question", "Solution"],
812
+ value=[],
813
+ interactive=False,
814
+ row_count=(0, "dynamic"),
815
+ col_count=4,
816
+ wrap=True,
817
+ height=260,
818
+ label="Examples",
819
+ )
820
+ with gr.Row():
821
+ row_picker = gr.Dropdown(label="Select example to load", choices=[], value=None, scale=2)
822
+ load_btn = gr.Button("Load to editor", scale=1)
823
+
824
+ # ---------- Wiring ----------
825
+ # Main classify (streaming)
826
  classify_btn.click(
827
+ fn=classify_solution_stream,
828
  inputs=[question_input, solution_input],
829
  outputs=[classification_output, explanation_output, status_output],
830
+ show_progress=False,
831
+ concurrency_limit=1,
832
+ )
833
+
834
+ # Surprise me → fills inputs from the CSV pool
835
+ surprise_btn.click(
836
+ fn=ui_surprise,
837
+ inputs=[selector_state, filter_dd],
838
+ outputs=[question_input, solution_input],
839
+ queue=True,
840
+ )
841
+
842
+ # See more → appends rows to the browser
843
+ see_more_btn.click(
844
+ fn=ui_see_more,
845
+ inputs=[selector_state, rows_state, filter_dd],
846
+ outputs=[rows_state, examples_df, row_picker],
847
+ queue=False,
848
  )
849
 
850
+ # Reset list → new selector + clear table
851
+ reset_list_btn.click(
852
+ fn=ui_reset_examples,
853
  inputs=None,
854
+ outputs=[selector_state, rows_state, examples_df, row_picker],
855
+ queue=False,
856
+ )
857
+
858
+ # Load selected row → fills main inputs
859
+ load_btn.click(
860
+ fn=ui_load_selected,
861
+ inputs=[rows_state, row_picker],
862
+ outputs=[question_input, solution_input],
863
  queue=False,
864
  )
865
 
866
+ # Enable queue for streaming
867
  app.queue()
868
 
869
  if __name__ == "__main__":
870
  app.launch()
871
 
872
+