# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import torch | |
import requests | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer | |
from qdrant_client import QdrantClient | |
from datetime import datetime | |
import dspy | |
import json | |
# === Load Models ===
# Hugging Face model IDs, hoisted into constants so they are easy to find/swap.
ZERO_SHOT_MODEL_ID = "facebook/bart-large-mnli"
EMBEDDING_MODEL_ID = "intfloat/e5-large"

# Zero-shot NLI classifier consumed by is_valid_math_question().
print("Loading zero-shot classifier...")
classifier = pipeline(task="zero-shot-classification", model=ZERO_SHOT_MODEL_ID)

# Sentence encoder whose vectors are used to search the Qdrant collection.
print("Loading embedding model...")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)
# === Qdrant Setup ===
# Qdrant opened via path= (local mode) rather than a server URL; data is
# persisted under this directory.
QDRANT_DATA_PATH = "qdrant_data"
print("Connecting to Qdrant...")
qdrant_client = QdrantClient(path=QDRANT_DATA_PATH)

# Collection searched by retrieve_from_qdrant().
collection_name = "math_problems"
# === Guard Function ===
def is_valid_math_question(text, *, threshold=0.7):
    """Return True if *text* looks like a math question.

    Runs the module-level zero-shot ``classifier`` with the candidate labels
    "math" / "not math" and accepts the text only when "math" is the
    top-ranked label and its score is strictly above *threshold*.

    Args:
        text: The user's question. ``None``, empty, or whitespace-only input
            is rejected without invoking the classifier.
        threshold: Minimum confidence required for the "math" label. The
            default of 0.7 matches the value previously hard-coded.

    Returns:
        bool: True when the question is accepted as math.
    """
    # Blank input (e.g. an empty Gradio textbox) carries no signal; skip the
    # expensive model call entirely.
    if not text or not text.strip():
        return False
    result = classifier(text, ["math", "not math"])
    # The zero-shot pipeline orders labels by descending score, so index 0 is
    # the best-scoring label.
    return result["labels"][0] == "math" and result["scores"][0] > threshold
# === Retrieval ===
def retrieve_from_qdrant(query, *, limit=3):
    """Return the payloads of the Qdrant points nearest to *query*.

    The query is embedded with the module-level ``embedding_model`` and
    searched in ``collection_name``.

    Args:
        query: Question text. Blank or ``None`` input returns ``[]`` without
            touching the model or the database.
        limit: Maximum number of hits to request (default 3, the value
            previously hard-coded).

    Returns:
        list[dict]: Payload dicts of the hits, in the order Qdrant returned
        them. Hits without a payload are omitted.
    """
    if not query or not query.strip():
        return []
    # NOTE(review): intfloat/e5 models are documented to expect a "query: "
    # prefix on search queries. Whether one belongs here depends on how the
    # collection was indexed; confirm before changing.
    query_vector = embedding_model.encode(query).tolist()
    # NOTE(review): QdrantClient.search is deprecated in recent qdrant-client
    # releases in favour of query_points(); kept for the pinned version.
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    # A point's payload may be None. Drop those so every returned element is
    # a dict that callers can safely probe with `in` / `.get`.
    return [hit.payload for hit in (hits or []) if hit.payload]
# === Web Search ===
def web_search_tavily(query, *, timeout=30):
    """Ask the Tavily search API for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable.
    The key that used to be hard-coded here was committed in plain text and
    should be revoked.

    Args:
        query: The question to search for.
        timeout: Seconds to wait for Tavily before giving up.

    Returns:
        str: Tavily's answer, or a human-readable fallback message when the
        key is missing, the request fails, or no answer is returned.
    """
    import os

    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        return "Web search is unavailable: TAVILY_API_KEY is not set."
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily only generates the top-level "answer" when asked to.
                "include_answer": True,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError):
        # Network/HTTP errors or a non-JSON body: degrade to a message rather
        # than crashing the Gradio handler.
        return "Web search failed. Please try again later."
    answer = data.get("answer") if isinstance(data, dict) else None
    # `or` (not a .get default) so an explicit null answer also falls back.
    return answer or "No answer found from Tavily."
# === DSPy Signature ===
# DSPy signature for math QA: two inputs (question, retrieved_context) and one
# output (answer).
# NOTE(review): in the code shown, no program references MathAnswer; the
# programs construct their results directly. Confirm whether this is dead code.
# NOTE(review): documentation is kept in comments rather than a class
# docstring because DSPy uses a Signature's docstring as its task instructions.
class MathAnswer(dspy.Signature):
    # The user's math question.
    question = dspy.InputField()
    # Context text retrieved for the question.
    retrieved_context = dspy.InputField()
    # The answer to be produced.
    answer = dspy.OutputField()
# === DSPy Programs ===
class MathRetrievalQA(dspy.Program):
    """Answer a question from solutions stored in the Qdrant collection."""

    def forward(self, question):
        """Build a response from the ``"solution"`` fields of the nearest hits.

        Returns:
            dict with keys ``"answer"`` and ``"retrieved_context"``. Both are
            empty strings when no usable solution was retrieved; the router
            treats an empty answer as the signal to fall back to web search.
        """
        context_items = retrieve_from_qdrant(question)
        # Skip None payloads and missing/empty solutions. str() guards against
        # non-string payload values, which would make str.join raise.
        solutions = [
            str(item["solution"])
            for item in context_items
            if item and item.get("solution")
        ]
        context = "\n".join(solutions)
        if not context:
            return {"answer": "", "retrieved_context": ""}
        # NOTE(review): the "answer" returned here is the prompt text itself.
        # No language model is called and the MathAnswer signature is unused.
        # Confirm whether an LM step (e.g. dspy.Predict(MathAnswer)) was intended.
        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        return {"answer": prompt, "retrieved_context": context}
class WebFallbackQA(dspy.Program):
    """Answer a question via Tavily web search (used when retrieval is empty)."""

    def forward(self, question):
        # The context field only records the answer's source, not real context.
        return dict(answer=web_search_tavily(question), retrieved_context="Tavily")
class MathRouter(dspy.Program):
    """Route a question: reject non-math, try Qdrant retrieval, else the web."""

    def __init__(self):
        super().__init__()
        # Build the sub-programs once rather than on every request. Holding
        # them as attributes is also the usual DSPy pattern for composition.
        self.retrieval = MathRetrievalQA()
        self.web_fallback = WebFallbackQA()

    def forward(self, question):
        """Return a dict with ``"answer"`` and ``"retrieved_context"`` keys."""
        if not is_valid_math_question(question):
            return {
                "answer": "❌ Only math questions are accepted. Please rephrase.",
                "retrieved_context": "",
            }
        result = self.retrieval.forward(question)
        # An empty answer means retrieval found nothing usable.
        if result["answer"]:
            return result
        return self.web_fallback.forward(question)


router = MathRouter()
# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer, *, path="feedback.json"):
    """Append one feedback record to *path* as a single JSON line.

    Despite the default ``.json`` name, the file is JSON Lines: each call
    appends one independent JSON object terminated by a newline.

    Args:
        question: The question that was answered.
        answer: The model's answer (stored under ``"model_answer"``).
        feedback: The user's rating, as passed in by the UI.
        correct_answer: Optional user-supplied correct answer.
        path: Destination file. Defaults to the previously hard-coded
            ``"feedback.json"``.
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now()),  # local, naive timestamp
    }
    # Serialize first so the whole record goes out in a single write call.
    line = json.dumps(entry) + "\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)
# === Gradio Functions ===
def ask_question(question):
    """Gradio handler for the "Get Answer" button.

    Returns:
        A 3-tuple ``(answer, question, answer)``: the text for the visible
        answer panel, followed by the values stored in the two hidden fields
        that the feedback handler reads later.
    """
    # Blank submissions would otherwise run the full model pipeline.
    if not question or not question.strip():
        return "Please enter a math question.", "", ""
    result = router.forward(question)
    answer = result["answer"]
    return answer, question, answer
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Gradio handler for "Submit Feedback": persist the rating, return a status."""
    # The hidden question stays empty until an answer has been generated;
    # storing feedback then would record a row with no question attached.
    if not question:
        return "⚠️ Please ask a question before submitting feedback."
    store_feedback(question, model_answer, feedback, correct_answer)
    return "✅ Feedback received. Thank you!"
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")
    # NOTE(review): the original indentation was lost in this copy, so which
    # components sat inside this Row is a best guess; adjust the layout.
    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
    answer_output = gr.Markdown()
    # Invisible fields that remember the last question/answer for feedback.
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)
    submit_btn = gr.Button("Get Answer")
    submit_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, hidden_q, hidden_a],
    )

    # Feedback section on same page
    gr.Markdown("### 💬 Give Feedback")
    fb_correct = gr.Textbox(label="Correct Answer (optional)")
    # The two choices must be distinct strings. Mis-decoded emoji had
    # collapsed both to the same value, making the ratings indistinguishable.
    fb_like = gr.Radio(["👍", "👎"], label="Was the answer helpful?")
    fb_submit_btn = gr.Button("Submit Feedback")
    fb_status = gr.Markdown()
    fb_submit_btn.click(
        fn=submit_feedback,
        inputs=[hidden_q, hidden_a, fb_like, fb_correct],
        outputs=[fb_status],
    )

demo.launch(share=True, debug=True)