# Hugging Face Spaces page header (copied from the Space UI): status "Sleeping".
import gradio as gr | |
import torch | |
import requests | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer | |
from qdrant_client import QdrantClient | |
from datetime import datetime | |
import dspy | |
import json | |
import google.generativeai as genai | |
# Configure the Gemini client from the environment (e.g. a Hugging Face
# Spaces secret). API keys must not be committed to source; the key that was
# previously hard-coded here is exposed and should be revoked.
import os

_gemini_api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if _gemini_api_key:
    genai.configure(api_key=_gemini_api_key)
else:
    print("Warning: GEMINI_API_KEY is not set; Gemini requests will likely fail.")
def output_guard(answer, *, min_length=20):
    """Return True when *answer* looks like a usable model response.

    An answer is rejected when it is missing (``None`` or empty) or when,
    after stripping surrounding whitespace, it is shorter than *min_length*
    characters.

    Args:
        answer: The generated answer text; may be ``None``.
        min_length: Minimum number of characters required after stripping.
            Defaults to 20, the threshold this guard has always used.

    Returns:
        ``True`` if the answer passes the guard, ``False`` otherwise.
    """
    if not answer or len(answer.strip()) < min_length:
        print("Output guard triggered: answer too short or empty.")
        return False
    # Further quality checks can be added here.
    return True
import os | |
from datetime import datetime | |
# Relative location of the JSON-lines feedback log. It lives in the app's
# working directory, which on Hugging Face Spaces is wiped whenever the
# Space restarts, so collected feedback does not survive a restart.
feedback_path: str = "feedback.json"
def store_feedback(question, answer, feedback, correct_answer, path=None):
    """Append one feedback record to the JSON-lines feedback log.

    Each call writes a single JSON object followed by a newline, so the file
    holds one record per line rather than a single JSON document.

    Args:
        question: The question the user asked.
        answer: The answer the model produced.
        feedback: The user's rating as selected in the UI (may be ``None``).
        correct_answer: Optional correction or free-text comment.
        path: File to append to; defaults to the module-level
            ``feedback_path``.

    Returns:
        ``True`` if the record was written, ``False`` if writing failed.
        Failures are reported with a print but never raised, so a broken
        log cannot break the UI.
    """
    target = feedback_path if path is None else path
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now()),
    }
    print("Attempting to store feedback:", entry)
    try:
        line = json.dumps(entry) + "\n"
        with open(target, "a", encoding="utf-8") as f:
            f.write(line)
    except (OSError, TypeError, ValueError) as e:
        # OSError: the file cannot be opened or written; TypeError/ValueError:
        # a field is not JSON-serializable. Either way, stay best-effort.
        print("❌ Error writing feedback:", e)
        return False
    print("✅ Feedback saved at", target)
    return True
import re | |
def latex_to_plain_math(latex_expr):
    r"""Convert a LaTeX math snippet into plain-text math notation.

    Rewrites ``\frac{a}{b}`` as ``(a) / (b)`` and ``\sqrt{x}`` as ``√(x)``.
    Squares and cubes (``x^2``, ``x^{2}``) become ``x²`` and ``x³``.
    ``\pm``, ``\cdot`` and ``\times`` become ``±``, ``⋅`` and ``×``. Any
    braces that remain afterwards are dropped. Other powers keep caret form
    (``x^4``).

    The ``\frac``/``\sqrt`` rewrites use non-greedy regexes rather than a
    LaTeX parser, so heavily nested arguments may not convert cleanly.

    Args:
        latex_expr: LaTeX source text. ``None`` or ``""`` yields ``""``.

    Returns:
        The plain-text rendering of ``latex_expr``.
    """
    if not latex_expr:
        return ""
    text = latex_expr.strip()
    # Unwrap braced single-digit squares/cubes (x^{2} -> x^2) so they get the
    # same superscript treatment as the bare form below.
    text = re.sub(r"\^\{([23])\}", r"^\1", text)
    text = re.sub(r"\\frac\{(.+?)\}\{(.+?)\}", r"(\1) / (\2)", text)
    text = re.sub(r"\\sqrt\{(.+?)\}", r"√(\1)", text)
    # Bare ^2/^3 become superscripts; every other exponent keeps caret form.
    text = text.replace("^2", "²").replace("^3", "³")
    text = text.replace("\\pm", "±")
    text = text.replace("\\cdot", "⋅")
    text = text.replace("\\times", "×")
    # Remaining braces only grouped LaTeX tokens; drop them.
    return text.replace("{", "").replace("}", "")
# === Load Models ===
# Zero-shot classifier consulted by the "is this a math question?" guard.
print("Loading zero-shot classifier...")
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Sentence-embedding model used to turn queries into Qdrant search vectors.
print("Loading embedding model...")
embedding_model = SentenceTransformer("intfloat/e5-large")
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline | |
# === Qdrant Setup ===
# Collection searched for problems similar to the user's question.
collection_name = "math_problems"

print("Connecting to Qdrant...")
# Local, file-backed Qdrant instance stored in the ./qdrant_data directory.
qdrant_client = QdrantClient(path="qdrant_data")
# === Guard Function ===
def is_valid_math_question(text, *, threshold=0.7):
    """Return True if the zero-shot classifier labels *text* as a math question.

    Blank input (``None``, empty or whitespace-only) is rejected without
    invoking the classifier.

    Args:
        text: The user's question.
        threshold: Minimum score the top "math" label must exceed.
            Defaults to 0.7, the value this guard has always used.

    Returns:
        ``True`` when the top label is "math" with a score above *threshold*.
    """
    if not text or not text.strip():
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    return result["labels"][0] == "math" and result["scores"][0] > threshold
# === Retrieval ===
def retrieve_from_qdrant(query, limit=1):
    """Return the payloads of the stored problems most similar to *query*.

    Args:
        query: Question text to embed and search for.
        limit: Maximum number of hits to request (default 1, as before).

    Returns:
        A list of payload dicts in the order Qdrant returned them. Hits with
        no payload are skipped. An empty list means nothing was found.
    """
    print("Retrieving context from Qdrant...")
    # NOTE(review): e5 embedding models are usually trained with "query: " /
    # "passage: " prefixes; confirm how the collection was indexed before
    # changing what is encoded here.
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    print("Retrieved hits:", hits)
    # Callers index into each payload, so drop hits that carry none.
    return [hit.payload for hit in hits if hit.payload]
# === Web Search ===
def web_search_tavily(query, *, timeout=30):
    """Ask the Tavily search API for a short answer to *query*.

    The API key comes from the ``TAVILY_API_KEY`` environment variable. The
    key previously hard-coded here was committed to source and should be
    revoked.

    Args:
        query: The question to search for.
        timeout: Seconds to wait for Tavily before giving up.

    Returns:
        Tavily's answer text, or "No answer found from Tavily." when the key
        is missing, the request fails, or the response carries no answer.
    """
    fallback = "No answer found from Tavily."
    print("Calling Tavily...")
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        print("TAVILY_API_KEY is not set; skipping web search.")
        return fallback
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily's API reference lists include_answer as off by
                # default; without it "answer" is not populated.
                "include_answer": True,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # Network/HTTP errors and non-JSON bodies degrade to the fallback
        # instead of crashing the request handler.
        print("Tavily request failed:", e)
        return fallback
    answer = data.get("answer") if isinstance(data, dict) else None
    # The key can be present with a null value, so use `or` rather than a
    # .get() default.
    return answer or fallback
# === DSPy Signature ===
class MathAnswer(dspy.Signature):
    # DSPy signature mapping (question, retrieved_context) -> answer.
    # Comments are used instead of a docstring because DSPy uses a
    # Signature's docstring as its task instructions.
    # NOTE(review): not referenced in the code visible here; the programs
    # below call Gemini directly instead of going through this signature.
    question = dspy.InputField()  # the user's math question
    retrieved_context = dspy.InputField()  # context retrieved for the question
    answer = dspy.OutputField()  # the generated answer
# === DSPy Programs === | |
import google.generativeai as genai | |
# Gemini is configured once near the top of this module. A second
# configure call here only duplicated it, along with another hard-coded
# copy of the API key, so it has been removed.
class MathRetrievalQA(dspy.Program):
    """Answer a math question from the closest solution stored in Qdrant.

    The retrieved solution is converted to plain-text math and handed to
    Gemini, which rewrites it as a textbook-style answer.
    """

    @staticmethod
    def _build_prompt(question, plain_context):
        """Return the Gemini prompt for *question* given a plain-text reference solution."""
        return (
            "You are a math textbook author. Write a clear, professional, "
            "well-formatted step-by-step solution to the math problem below.\n"
            "\n"
            "Formatting rules:\n"
            "- Write plain text using standard math symbols such as "
            "+, -, ×, /, =, ±, √ and superscripts (x² or x^4).\n"
            "- Do not use LaTeX syntax, backslashes, or dollar signs.\n"
            "- Write fractions in plain text with a slash, e.g. (a + b) / c.\n"
            "- Present the steps under clear numbered headings.\n"
            "\n"
            f"Problem: {question}\n"
            "\n"
            "Reference solution (use it if helpful):\n"
            f"{plain_context}\n"
            "\n"
            "Write only the formatted solution, as it would appear in a math textbook."
        )

    def forward(self, question):
        """Retrieve a stored solution for *question* and have Gemini format it.

        Returns:
            A dict with ``answer`` and ``retrieved_context`` (the raw retrieved
            solution text). ``answer`` is ``""`` when nothing usable was
            retrieved, which lets the caller fall back to another source. It
            is a warning message when the Gemini call fails.
        """
        print("Inside MathRetrievalQA...")
        context_items = retrieve_from_qdrant(question)
        # Skip payloads that are empty or have no "solution" value.
        context = "\n".join(
            str(item["solution"])
            for item in context_items
            if item and item.get("solution")
        )
        print("Context for generation:", context)
        if not context:
            return {"answer": "", "retrieved_context": ""}

        plain_context = latex_to_plain_math(context)
        print("Plain-text context:", plain_context)
        prompt = self._build_prompt(question, plain_context)
        try:
            model = genai.GenerativeModel("gemini-2.0-flash")  # or "gemini-1.5-flash"
            response = model.generate_content(prompt)
            formatted_answer = response.text
        except Exception as e:  # boundary: any SDK/network failure becomes a message
            print("Gemini generation error:", e)
            return {"answer": "⚠️ Gemini failed to generate an answer.", "retrieved_context": context}
        print("Gemini Answer:", formatted_answer)
        return {"answer": formatted_answer, "retrieved_context": context}
class WebFallbackQA(dspy.Program):
    """Fallback program that answers a question with a Tavily web search."""

    def forward(self, question):
        """Return Tavily's answer for *question*, with "Tavily" as the context source."""
        print("Fallback to Tavily...")
        return {
            "answer": web_search_tavily(question),
            "retrieved_context": "Tavily",
        }
class MathRouter(dspy.Program):
    """Route a question: reject non-math input, try retrieval, else use the web."""

    def forward(self, question):
        """Answer *question*, returning a dict with ``answer`` and ``retrieved_context``.

        Non-math questions get a rejection message. Otherwise the retrieval
        answer is used if it passes ``output_guard``. If it does not, the
        Tavily web fallback answers instead.
        """
        print("Routing question:", question)
        if not is_valid_math_question(question):
            return {"answer": "❌ Only math questions are accepted. Please rephrase.", "retrieved_context": ""}
        result = MathRetrievalQA().forward(question)
        # Output guard: an empty or very short retrieval answer counts as a
        # miss, so the web fallback gets a chance to answer instead.
        if output_guard(result.get("answer")):
            return result
        return WebFallbackQA().forward(question)
# Single shared router instance used by the Gradio handlers.
router = MathRouter()


# === Gradio Functions ===
def ask_question(question):
    """Answer *question* via the router.

    Returns a triple for the UI outputs: the answer (shown to the user), the
    question, and the answer again (both kept in hidden fields for feedback).
    """
    print("ask_question() called with:", question)
    outcome = router.forward(question)
    print("Result:", outcome)
    reply = outcome["answer"]
    return reply, question, reply
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Record the user's feedback on the last answer and return a status message.

    Args:
        question: The question that was answered (from a hidden UI field).
        model_answer: The answer that was shown (from a hidden UI field).
        feedback: The selected rating; may be ``None`` if none was chosen.
        correct_answer: Optional correction or comment.

    Returns:
        A status string for the UI.
    """
    if not question:
        # The hidden question field stays empty until an answer has been
        # generated, so there is nothing to attach the feedback to yet.
        return "Please ask a question before submitting feedback."
    saved = store_feedback(question, model_answer, feedback, correct_answer)
    # Only an explicit False is treated as a failure, so a store_feedback
    # that returns None keeps the success path.
    if saved is False:
        return "❌ Could not save feedback. Please try again."
    return "✅ Feedback received. Thank you!"
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Agent")
    with gr.Tab("Ask a Math Question & Submit Feedback"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
            submit_btn = gr.Button("Get Answer")
        gr.Markdown("### 🧠 Answer:")
        answer_output = gr.Markdown()

        # Hidden fields carry the last question and answer into the feedback handler.
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)

        # Connect the submit button to ask_question.
        submit_btn.click(
            fn=ask_question,
            inputs=[question_input],
            outputs=[answer_output, hidden_q, hidden_a],
        )

        gr.Markdown("### 📝 Submit Feedback")
        # The two choices must be distinct strings, or the stored rating
        # cannot tell positive feedback from negative.
        fb_like = gr.Radio(["👍", "👎"], label="Was the answer helpful?")
        fb_correct = gr.Textbox(label="Correct Answer (optional) or Comments")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)
        feedback_file = gr.File(label="📁 Download Saved Feedback", interactive=False)

        # Store the feedback first, then refresh the download link. Chaining
        # with .then() guarantees the write has finished; the existence check
        # avoids handing gr.File a path that was never created.
        fb_submit_btn.click(
            fn=submit_feedback,
            inputs=[hidden_q, hidden_a, fb_like, fb_correct],
            outputs=[fb_status],
        ).then(
            fn=lambda: feedback_path if os.path.exists(feedback_path) else None,
            outputs=[feedback_file],
        )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)