# Hugging Face Spaces page residue (status: Sleeping) — not part of the program.
import json
import os
from datetime import datetime

import dspy
import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# === Load Models ===
# Models are created once at import time; the functions below use these globals.
print("Loading zero-shot classifier...")
# Gate for non-math input, used by is_valid_math_question().
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

print("Loading embedding model...")
# Encodes queries for the Qdrant search in retrieve_from_qdrant().
# NOTE(review): E5 models are documented to expect "query: " / "passage: "
# prefixes; confirm how the collection was indexed before changing query text.
embedding_model = SentenceTransformer("intfloat/e5-large")

# No text-generation model is loaded yet: MathRetrievalQA currently emits a
# placeholder answer instead of calling a generator.

# === Qdrant Setup ===
print("Connecting to Qdrant...")
# Local, file-backed Qdrant client stored under ./qdrant_data.
qdrant_client = QdrantClient(path="qdrant_data")
# Collection searched by retrieve_from_qdrant().
collection_name = "math_problems"
# === Guard Function ===
def is_valid_math_question(text, threshold=0.7):
    """Return True if the zero-shot classifier confidently labels *text* as math.

    Args:
        text: The user's question.
        threshold: The "math" label's score must be strictly greater than this
            (default 0.7, the previously hard-coded cut-off).

    Returns:
        True only when "math" is the top-ranked label and its score exceeds
        ``threshold``; False for empty, blank, or None input.
    """
    # Blank input can never be a math question; skip the (slow) model call.
    if not text or not text.strip():
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    # The zero-shot pipeline lists labels best-first, so index 0 is the winner.
    return result["labels"][0] == "math" and result["scores"][0] > threshold
# === Retrieval ===
def retrieve_from_qdrant(query, limit=3):
    """Return the payloads of the Qdrant points nearest to *query*.

    Args:
        query: Free-text question, embedded with ``embedding_model``.
        limit: Maximum number of hits to request (default 3, the previously
            hard-coded value).

    Returns:
        A list of payload dicts in the order Qdrant returned them. Hits that
        carry no payload are skipped, so every element is a dict. Returns an
        empty list when nothing is found.
    """
    print("Retrieving context from Qdrant...")
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    print("Retrieved hits:", hits)
    # A hit's payload may be None; dropping those keeps `"solution" in item`
    # in MathRetrievalQA from raising TypeError.
    return [hit.payload for hit in hits or [] if hit.payload is not None]
# === Web Search ===
TAVILY_SEARCH_URL = "https://api.tavily.com/search"
# Returned on every failure path; WebFallbackQA treats "No answer found" as a miss.
_TAVILY_NO_ANSWER = "No answer found from Tavily."


def web_search_tavily(query, timeout=30):
    """Ask Tavily's search API for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable
    (e.g. a Space secret) rather than being hard-coded in source.

    Args:
        query: Question to search for.
        timeout: Seconds to wait for the HTTP request (default 30).

    Returns:
        Tavily's answer string, or ``"No answer found from Tavily."`` when no
        key is configured, the request fails, or no usable answer comes back.
    """
    print("Calling Tavily...")
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        print("TAVILY_API_KEY is not set; skipping web search.")
        return _TAVILY_NO_ANSWER
    try:
        response = requests.post(
            TAVILY_SEARCH_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "query": query,
                "search_depth": "advanced",
                # Without this flag Tavily returns only search results, no "answer".
                "include_answer": True,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        answer = response.json().get("answer")
    except (requests.RequestException, ValueError) as err:
        print("Tavily request failed:", err)
        return _TAVILY_NO_ANSWER
    # "answer" can be null or missing; callers expect a string.
    if isinstance(answer, str) and answer.strip():
        return answer
    return _TAVILY_NO_ANSWER
# === DSPy Signature ===
# Input/output schema: a question plus supporting context in, an answer out.
# The explicit `str` annotations spell out DSPy's default field type.
# NOTE(review): nothing in this file references MathAnswer yet; the programs
# below build their prompt by hand.
# Deliberately no docstring: DSPy uses a Signature's docstring as its task
# instructions, so adding one would change the generated prompt.
class MathAnswer(dspy.Signature):
    question: str = dspy.InputField()  # the user's math question
    retrieved_context: str = dspy.InputField()  # supporting context for the question
    answer: str = dspy.OutputField()  # the answer to produce
# === DSPy Programs (retrieval answers pass through an output guard) ===
class MathRetrievalQA(dspy.Program):
    """Answer a question from solutions stored in the Qdrant collection.

    ``forward`` returns a dict with ``"answer"`` and ``"retrieved_context"``.
    An empty ``"answer"`` means no usable answer was produced.
    """

    def forward(self, question):
        print("Inside MathRetrievalQA...")
        solutions = []
        for payload in retrieve_from_qdrant(question):
            if "solution" in payload:
                solutions.append(payload["solution"])
        context = "\n".join(solutions)
        print("Context for generation:", context)

        # Nothing retrieved: empty answer and empty context.
        if not context:
            return {"answer": "", "retrieved_context": ""}

        # TODO: swap the placeholder below for a real model call.
        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        print("Prompt for generation:", prompt)
        generated_answer = "This is a placeholder answer based on the context."
        print("Generated answer:", generated_answer)

        # Output guard: reject empty, very short (< 10 chars) or "I don't know" answers.
        rejected = (
            not generated_answer
            or len(generated_answer.strip()) < 10
            or "I don't know" in generated_answer
        )
        final_answer = "" if rejected else generated_answer.strip()
        return {"answer": final_answer, "retrieved_context": context}
class WebFallbackQA(dspy.Program):
    """Answer a question via Tavily web search."""

    def forward(self, question):
        """Return ``{"answer": ..., "retrieved_context": "Tavily"}`` for *question*.

        Tavily's reply is replaced by an apology when it is not a string, is
        shorter than 10 characters after stripping, or contains Tavily's
        "No answer found" sentinel.
        """
        print("Fallback to Tavily...")
        answer = web_search_tavily(question)
        # isinstance also covers None, so .strip() below is always safe.
        if (
            not isinstance(answer, str)
            or len(answer.strip()) < 10
            or "No answer found" in answer
        ):
            answer = "❌ Sorry, I couldn't find a reliable answer."
        return {"answer": answer.strip(), "retrieved_context": "Tavily"}
class MathRouter(dspy.Program):
    """Reject non-math input, try Qdrant retrieval, and fall back to web search."""

    def forward(self, question):
        """Return an ``{"answer", "retrieved_context"}`` dict for *question*."""
        print("Routing question:", question)
        if not is_valid_math_question(question):
            return {
                "answer": "❌ Only math questions are accepted. Please rephrase.",
                "retrieved_context": "",
            }
        result = MathRetrievalQA().forward(question)
        # An empty answer means retrieval (or its output guard) produced nothing.
        if result["answer"]:
            return result
        return WebFallbackQA().forward(question)
# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer, path="feedback.json"):
    """Append one feedback record to a JSON-Lines file.

    Args:
        question: The question that was asked.
        answer: The model's answer being rated (stored as ``model_answer``).
        feedback: The user's rating.
        correct_answer: Optional user-supplied correct answer.
        path: File to append to (default ``feedback.json`` in the working dir).
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        # Local (naive) wall-clock time rendered with str(datetime).
        "timestamp": str(datetime.now()),
    }
    print("Storing feedback:", entry)
    # One JSON object per line, so appending never requires re-reading the file.
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
def load_feedback_entries(path="feedback.json"):
    """Read every feedback record from a JSON-Lines file.

    Args:
        path: File to read (default ``feedback.json`` in the working dir).

    Returns:
        A list of record dicts in file order; an empty list if the file does
        not exist. Blank lines are ignored. Malformed or non-object lines are
        skipped with a printed warning, so one bad line cannot break the table.
    """
    entries = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as err:
                    print(f"Skipping malformed feedback line {lineno}: {err}")
                    continue
                if isinstance(entry, dict):
                    entries.append(entry)
                else:
                    print(f"Skipping non-object feedback line {lineno}")
    except FileNotFoundError:
        pass
    return entries
# === Gradio Functions ===
# Shared router instance; MathRouter.forward keeps no per-call state on self.
router = MathRouter()


def ask_question(question):
    """Gradio handler for the "Get Answer" button.

    Args:
        question: Text from the question textbox.

    Returns:
        ``(answer, question, answer)``: the displayed answer, plus copies of the
        question and answer for the hidden textboxes.
    """
    print("ask_question() called with:", question)
    result = router.forward(question)
    print("Result:", result)
    answer = result["answer"]
    return answer, question, answer
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Persist one feedback record and return a status message for the UI."""
    store_feedback(question, model_answer, feedback, correct_answer)
    return "✅ Feedback received. Thank you!"
# === Gradio UI ===
# Stored-record keys, in the column order of the feedback table below.
_FEEDBACK_KEYS = ("question", "model_answer", "feedback", "correct_answer", "timestamp")


def _feedback_rows(entries):
    """Turn stored feedback dicts into table rows; missing fields render blank."""
    return [[entry.get(key, "") for key in _FEEDBACK_KEYS] for entry in entries]


with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
        gr.Markdown("### 🧠 Answer:")
        answer_output = gr.Markdown()
        # Invisible holders for the last question/answer returned by ask_question().
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn = gr.Button("Get Answer")
        ask_event = submit_btn.click(
            fn=ask_question,
            inputs=[question_input],
            outputs=[answer_output, hidden_q, hidden_a],
        )

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        # Two distinct choices; the selected emoji is stored as "feedback".
        fb_like = gr.Radio(["👍", "👎"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)
        feedback_display = gr.Dataframe(
            headers=["Question", "Answer", "Feedback", "Correct Answer", "Timestamp"],
            row_count=10,
            max_rows=50,
            wrap=True,
        )

        def feedback_submission_and_display(question, answer, feedback, correct_answer):
            """Store one feedback record, then return a status and all stored rows."""
            store_feedback(question, answer, feedback, correct_answer)
            rows = _feedback_rows(load_feedback_entries())
            return "✅ Feedback received. Thank you!", rows

        fb_submit_btn.click(
            fn=feedback_submission_and_display,
            inputs=[fb_question, fb_answer, fb_like, fb_correct],
            outputs=[fb_status, feedback_display],
        )

    # After each answer, pre-fill the feedback form from the hidden copies so
    # users don't have to retype the question and answer.
    ask_event.then(
        fn=lambda q, a: (q, a),
        inputs=[hidden_q, hidden_a],
        outputs=[fb_question, fb_answer],
    )

demo.launch(share=True, debug=True)