# Math_Agent / app.py
# Last commit 7586e9b ("Update app.py", verified) by manasagangotri; file size 7.64 kB.
import gradio as gr
import torch
import requests
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from datetime import datetime
import dspy
import json
import google.generativeai as genai
# Configure the Gemini API from the environment. The key previously hard-coded
# here was committed to source control and must be treated as leaked (rotate
# it). If GEMINI_API_KEY is unset, api_key=None leaves key discovery to the
# library's own environment lookup (TODO confirm which variables your
# google-generativeai version reads, e.g. GOOGLE_API_KEY).
import os

genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
import re
def latex_to_plain_math(latex_expr):
    r"""Convert a LaTeX math snippet into a plain-text approximation.

    Used to simplify retrieved solutions before they are placed in the Gemini
    prompt. Converts \frac{a}{b} -> (a) / (b) and \sqrt{x} -> √(x). Converts
    squared/cubed exponents, bare (^2) or braced (^{2}), to ² / ³. Replaces a
    few operator/relation commands (\pm, \mp, \cdot, \cdots, \ldots, \times,
    \div, \leq, \geq, \neq) with Unicode symbols. Drops \left / \right
    delimiter sizing. Any remaining braces are stripped; other exponents
    stay in ^n form.

    Args:
        latex_expr: LaTeX source text.

    Returns:
        The converted, whitespace-stripped string.
    """
    text = latex_expr.strip()
    # \frac{a}{b} -> (a) / (b). The non-greedy groups backtrack, so simple
    # nesting such as \frac{\sqrt{x}}{2} still matches the outer pair.
    text = re.sub(r"\\frac\{(.+?)\}\{(.+?)\}", r"(\1) / (\2)", text)
    text = re.sub(r"\\sqrt\{(.+?)\}", r"√(\1)", text)
    # Treat a braced single-digit exponent (^{2}) like a bare one (^2) so both
    # spellings render the same way.
    text = re.sub(r"\^\{(\d)\}", r"^\1", text)
    text = text.replace("^2", "²").replace("^3", "³")
    # Drop \left / \right sizing (and the invisible "\left." / "\right."
    # delimiter) without touching commands like \leftarrow or \rightarrow.
    text = re.sub(r"\\(?:left|right)(?:\.|(?![A-Za-z]))", "", text)
    symbols = (
        ("pm", "±"),
        ("mp", "∓"),
        ("cdots", "⋯"),
        ("ldots", "…"),
        ("cdot", "⋅"),
        ("times", "×"),
        ("div", "÷"),
        ("leq", "≤"),
        ("geq", "≥"),
        ("neq", "≠"),
    )
    for command, symbol in symbols:
        # Match whole command names only: the lookahead stops \cdot from
        # eating the start of \cdots and \pm from eating \pmod.
        text = re.sub(rf"\\{command}(?![A-Za-z])", symbol, text)
    return text.replace("{", "").replace("}", "")
# === Load Models ===
# Zero-shot classifier used by is_valid_math_question() to gate non-math input.
print("Loading zero-shot classifier...")
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Sentence embedding model used to encode queries for the Qdrant search.
print("Loading embedding model...")
embedding_model = SentenceTransformer(
    "intfloat/e5-large",
)
# Seq2seq model classes, apparently kept for experimenting with a lighter local model; not used in the visible code.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# === Qdrant Setup ===
# Name of the Qdrant collection that retrieve_from_qdrant() searches.
collection_name = "math_problems"
print("Connecting to Qdrant...")
# Client backed by the local "qdrant_data" directory.
qdrant_client = QdrantClient(path="qdrant_data")
# === Guard Function ===
def is_valid_math_question(text, threshold=0.7):
    """Return True if the zero-shot classifier labels *text* as a math question.

    Args:
        text: The user's question.
        threshold: Minimum confidence the top label must exceed. Defaults to
            0.7, the value previously hard-coded.

    Returns:
        True when the top-ranked label is "math" with a score above
        ``threshold``; False otherwise, including for blank input.
    """
    # Blank input cannot be a math question; skip the (expensive) model call.
    if not text or not text.strip():
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    return result["labels"][0] == "math" and result["scores"][0] > threshold
# === Retrieval ===
def retrieve_from_qdrant(query, limit=3):
    """Return the payloads of the stored problems most similar to *query*.

    Args:
        query: The user's question text.
        limit: Maximum number of hits to request from Qdrant (default 3,
            the value previously hard-coded).

    Returns:
        A list of payload dicts, best match first. Hits without a payload
        are skipped, so callers can use ``in`` / ``[]`` on every item.
    """
    print("Retrieving context from Qdrant...")
    # NOTE(review): intfloat/e5 models are documented to expect a "query: "
    # prefix on queries (and "passage: " on indexed text). Confirm how the
    # collection was embedded before adding one.
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    print("Retrieved hits:", hits)
    # A hit's payload may be None (e.g. a point stored without one); dropping
    # those keeps MathRetrievalQA's `"solution" in item` from raising.
    return [hit.payload for hit in hits or [] if hit.payload is not None]
# === Web Search ===
def web_search_tavily(query, timeout=30):
    """Ask the Tavily search API for a direct answer to *query*.

    Args:
        query: The question to search for.
        timeout: Seconds to wait for Tavily before giving up.

    Returns:
        Tavily's synthesized answer string. If Tavily gives no answer, or the
        request or its JSON decoding fails, returns
        "No answer found from Tavily.". If the TAVILY_API_KEY environment
        variable is unset, returns a warning message instead.
    """
    import os

    print("Calling Tavily...")
    # The key comes from the environment. The key previously hard-coded here
    # was committed to source control and must be treated as leaked (rotate it).
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        print("TAVILY_API_KEY is not set; skipping web search.")
        return "⚠️ Web search is unavailable: TAVILY_API_KEY is not set."
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily only returns a synthesized "answer" when asked for one.
                "include_answer": True,
            },
            # Without a timeout a stalled connection would hang the UI callback.
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        print("Tavily request failed:", e)
        return "No answer found from Tavily."
    # "answer" can be present but null, so fall back on any falsy value.
    return data.get("answer") or "No answer found from Tavily."
# === DSPy Signature ===
class MathAnswer(dspy.Signature):
    # DSPy signature: takes the question and retrieved context, produces an answer.
    # NOTE(review): not referenced anywhere in the visible code; the programs
    # below call Gemini directly instead of going through this signature.
    # NOTE(review): DSPy generally uses a Signature's docstring as its task
    # instructions, so comments are used here instead of a docstring to leave
    # behavior unchanged. Confirm for your DSPy version.
    question = dspy.InputField()
    retrieved_context = dspy.InputField()
    answer = dspy.OutputField()
# === DSPy Programs ===
import google.generativeai as genai
# Gemini is already configured near the top of the module. Calling
# genai.configure() again here was redundant and duplicated the API key in source.
class MathRetrievalQA(dspy.Program):
    """Answer a math question from solutions stored in Qdrant, formatted by Gemini."""

    def forward(self, question):
        """Retrieve solutions to similar problems and ask Gemini to write one up.

        Args:
            question: The user's math question.

        Returns:
            A dict with "answer" and "retrieved_context". "answer" is empty
            when nothing usable was retrieved, which makes MathRouter fall
            back to web search. On a Gemini error it holds a warning message.
        """
        print("Inside MathRetrievalQA...")
        context_items = retrieve_from_qdrant(question)
        # Payloads may be None or lack a "solution" field; skip those.
        context = "\n".join(
            item["solution"] for item in context_items if item and "solution" in item
        )
        print("Context for generation:", context)
        if not context:
            return {"answer": "", "retrieved_context": ""}
        plain_context = latex_to_plain_math(context)
        print(plain_context)
        # The question itself must be in the prompt. Previously only the
        # retrieved solutions were sent, so Gemini could answer a neighbouring
        # problem instead of the one asked. The old prompt also asked for both
        # "proper LaTeX" and "no LaTeX"; it now consistently asks for plain text.
        prompt = f"""
You are a math textbook author. Write a clear, professional, and well-formatted solution for the following math problem.
Problem:
{question}
Present the solution as a clean, human-readable explanation as found in textbooks. Use standard math symbols like ±, √, fractions with slashes (e.g. (a + b)/c), and superscripts with ^. Do not use LaTeX syntax or backslashes. Do not wrap equations in dollar signs. Present the steps clearly using numbered headings. Keep all fractions in plain text form.
Use the following reference solutions to similar problems if needed:
{plain_context}
Write only the formatted solution, as it would appear in a math textbook.
"""
        try:
            model = genai.GenerativeModel("gemini-2.0-flash")  # or use 'gemini-1.5-flash'
            response = model.generate_content(prompt)
            formatted_answer = response.text
            print("Gemini Answer:", formatted_answer)
            return {"answer": formatted_answer, "retrieved_context": context}
        except Exception as e:
            # UI boundary: report the failure to the user instead of crashing.
            print("Gemini generation error:", e)
            return {"answer": "⚠️ Gemini failed to generate an answer.", "retrieved_context": context}
class WebFallbackQA(dspy.Program):
    """Answer a question with a Tavily web search when retrieval yields nothing."""

    def forward(self, question):
        """Return a dict with the Tavily answer and "Tavily" as the context source."""
        print("Fallback to Tavily...")
        return dict(answer=web_search_tavily(question), retrieved_context="Tavily")
class MathRouter(dspy.Program):
    """Route a question: reject non-math input, try retrieval, else search the web."""

    def forward(self, question):
        """Return a dict with "answer" and "retrieved_context" for *question*."""
        print("Routing question:", question)
        if not is_valid_math_question(question):
            # Use the same dict shape as the other branches, because
            # ask_question() indexes result["answer"]. The previous
            # dspy.Output(...) call is not part of DSPy's public API and made
            # this path crash.
            return {
                "answer": "❌ Only math questions are accepted. Please rephrase.",
                "retrieved_context": "",
            }
        result = MathRetrievalQA().forward(question)
        # An empty answer means retrieval found nothing; fall back to the web.
        return result if result["answer"] else WebFallbackQA().forward(question)


router = MathRouter()
# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer):
    """Append one feedback record as a JSON line to "feedback.json".

    The record holds the question, the model's answer, the user's rating,
    the optional corrected answer, and a local timestamp string. The file is
    opened in append mode, so it accumulates one JSON object per line.
    """
    record = dict(
        question=question,
        model_answer=answer,
        feedback=feedback,
        correct_answer=correct_answer,
        timestamp=str(datetime.now()),
    )
    print("Storing feedback:", record)
    with open("feedback.json", "a") as sink:
        sink.write(f"{json.dumps(record)}\n")
# === Gradio Functions ===
def ask_question(question):
    """Gradio handler for the "Get Answer" button.

    Returns the answer (for the Markdown output), followed by the question and
    the answer again for the two hidden textboxes.
    """
    print("ask_question() called with:", question)
    routed = router.forward(question)
    print("Result:", routed)
    answer_text = routed["answer"]
    return answer_text, question, answer_text
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Gradio handler: persist one piece of feedback and return a status message."""
    store_feedback(
        question=question,
        answer=model_answer,
        feedback=feedback,
        correct_answer=correct_answer,
    )
    return "✅ Feedback received. Thank you!"
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
        gr.Markdown("### 🧠 Answer:")
        answer_output = gr.Markdown()
        # Hidden copies of the last question/answer, filled by ask_question().
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn = gr.Button("Get Answer")
        ask_event = submit_btn.click(
            fn=ask_question,
            inputs=[question_input],
            outputs=[answer_output, hidden_q, hidden_a],
        )

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        fb_like = gr.Radio(["👍", "👎"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)
        fb_submit_btn.click(
            fn=submit_feedback,
            inputs=[fb_question, fb_answer, fb_like, fb_correct],
            outputs=[fb_status],
        )

    # The hidden fields used to be written but never read, so users had to
    # retype the question and answer. After each answer, pre-fill the feedback
    # form from them instead; the fields remain editable.
    ask_event.then(
        fn=lambda q, a: (q, a),
        inputs=[hidden_q, hidden_a],
        outputs=[fb_question, fb_answer],
    )

demo.launch(share=True, debug=True)