import json
import os
from datetime import datetime

import dspy
import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# === Load Models ===
print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

print("Loading embedding model...")
embedding_model = SentenceTransformer("intfloat/e5-large")

print("Loading text generation model...")
# Use a lighter model for testing
qa_pipeline = pipeline("text-generation", model="gpt2")

# === Qdrant Setup ===
print("Connecting to Qdrant...")
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"

# Minimum zero-shot score the "math" label must exceed for a question to be accepted.
MATH_CONFIDENCE_THRESHOLD = 0.7
# Number of nearest neighbours fetched from Qdrant per question.
RETRIEVAL_LIMIT = 3
# Upper bound (seconds) on the Tavily HTTP call so a hung request cannot block the UI.
TAVILY_TIMEOUT_SECONDS = 30
TAVILY_URL = "https://api.tavily.com/search"
NO_TAVILY_ANSWER = "No answer found from Tavily."


# === Guard Function ===
def is_valid_math_question(text):
    """Return True if the zero-shot classifier confidently labels *text* as math.

    Blank or whitespace-only input is rejected without invoking the classifier.
    """
    if not (text and text.strip()):
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    # The zero-shot pipeline orders labels by descending score, so index 0 is the top label.
    return result["labels"][0] == "math" and result["scores"][0] > MATH_CONFIDENCE_THRESHOLD


# === Retrieval ===
def retrieve_from_qdrant(query):
    """Return the payload dicts of the nearest Qdrant points for *query*.

    Hits with an empty or missing payload are skipped; an empty list means nothing was found.
    """
    print("Retrieving context from Qdrant...")
    # NOTE(review): the intfloat/e5 model card asks for a "query: " prefix on queries;
    # confirm how the stored vectors were encoded before changing the text passed here.
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=RETRIEVAL_LIMIT,
    )
    print("Retrieved hits:", hits)
    return [hit.payload for hit in hits if hit.payload]


# === Web Search ===
def web_search_tavily(query):
    """Ask Tavily for a synthesized answer to *query* and return it as text.

    The API key is read from the ``TAVILY_API_KEY`` environment variable. Missing keys,
    network/HTTP errors and malformed responses are reported as a message string
    instead of raising, so the UI always has something to display.
    """
    print("Calling Tavily...")
    # SECURITY: a Tavily key used to be hard-coded here and has been committed to source;
    # revoke it and supply the key through the environment instead.
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        return "Web search is unavailable: TAVILY_API_KEY is not set."
    try:
        response = requests.post(
            TAVILY_URL,
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily only returns the synthesized "answer" field when asked for it.
                "include_answer": True,
            },
            timeout=TAVILY_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as err:
        # ValueError covers a non-JSON body from response.json().
        print("Tavily request failed:", err)
        return "⚠️ Web search failed. Please try again later."
    # "answer" can be present but null, so fall back on any falsy value.
    return data.get("answer") or NO_TAVILY_ANSWER


# === DSPy Signature ===
# Declares the question + retrieved_context -> answer interface. The programs in this
# file call the Hugging Face pipeline directly rather than going through it.
class MathAnswer(dspy.Signature):
    question = dspy.InputField()
    retrieved_context = dspy.InputField()
    answer = dspy.OutputField()


# === DSPy Programs ===
class MathRetrievalQA(dspy.Module):
    """Generate an answer with the local model, grounded on Qdrant-retrieved solutions.

    Returns a prediction with an empty ``answer`` when no stored solution is available,
    which tells the router to fall back to web search.
    """

    def forward(self, question):
        print("Inside MathRetrievalQA...")
        try:
            context_items = retrieve_from_qdrant(question)
        except Exception as err:  # e.g. collection missing or Qdrant store unreadable
            # Degrade to "no context" so the router can still answer via web search.
            print("Qdrant retrieval failed:", err)
            context_items = []
        context = "\n".join(item["solution"] for item in context_items if "solution" in item)
        print("Context for generation:", context)
        if not context:
            return dspy.Prediction(answer="", retrieved_context="")
        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        print("Generating answer...")
        # return_full_text=False keeps only the continuation, not the echoed prompt.
        generations = qa_pipeline(prompt, max_new_tokens=100, return_full_text=False)
        answer = generations[0]["generated_text"].strip()
        print("Generated answer:", answer)
        return dspy.Prediction(answer=answer, retrieved_context=context)


class WebFallbackQA(dspy.Module):
    """Answer a question from Tavily web search."""

    def forward(self, question):
        print("Fallback to Tavily...")
        answer = web_search_tavily(question)
        return dspy.Prediction(answer=answer, retrieved_context="Tavily")


class MathRouter(dspy.Module):
    """Reject non-math input, try retrieval-grounded QA, then fall back to web search."""

    def __init__(self):
        super().__init__()
        self.retrieval_qa = MathRetrievalQA()
        self.web_qa = WebFallbackQA()

    def forward(self, question):
        print("Routing question:", question)
        if not is_valid_math_question(question):
            return dspy.Prediction(
                answer="❌ Only math questions are accepted. Please rephrase.",
                retrieved_context="",
            )
        result = self.retrieval_qa.forward(question)
        if result.answer:
            return result
        return self.web_qa.forward(question)


router = MathRouter()


# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer):
    """Append one timestamped feedback record to ``feedback.json``.

    Despite the extension, the file is JSON Lines: each call appends one JSON object
    followed by a newline.
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now()),
    }
    print("Storing feedback:", entry)
    with open("feedback.json", "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")


# === Gradio Functions ===
def ask_question(question):
    """Gradio handler: route *question* and return ``(answer, question, answer)``.

    The question and answer are echoed back so the UI can keep them for feedback.
    """
    print("ask_question() called with:", question)
    result = router.forward(question)
    print("Result:", result)
    return result.answer, question, result.answer


def submit_feedback(question, model_answer, feedback, correct_answer):
    """Gradio handler: persist the user's feedback and return a status message."""
    store_feedback(question, model_answer, feedback, correct_answer)
    return "✅ Feedback received. Thank you!"
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
            answer_output = gr.Markdown(label="Answer")
        submit_btn = gr.Button("Get Answer")

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        fb_like = gr.Radio(["👍", "👎"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)

    # Events are wired after both tabs exist so that ask_question's echoed
    # (question, answer) pre-fill the feedback form instead of landing in
    # hidden fields nothing ever reads.
    submit_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, fb_question, fb_answer],
    )
    fb_submit_btn.click(
        fn=submit_feedback,
        inputs=[fb_question, fb_answer, fb_like, fb_correct],
        outputs=[fb_status],
    )

# Launch only when run as a script, so importing the module (or `gradio app.py`
# reload mode) does not start a second server.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)