Spaces:
Sleeping
Sleeping
File size: 5,357 Bytes
5775448 9ec24d8 b7b20e2 5775448 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 02f7269 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import gradio as gr
import torch
import requests
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from datetime import datetime
import dspy
import json
# === Load Models ===
# All models are instantiated once at import time and shared by the functions
# below as module-level globals.
print("Loading zero-shot classifier...")
# Zero-shot classifier used by is_valid_math_question() to gate incoming text.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
print("Loading embedding model...")
# Sentence encoder used to turn a query into the vector sent to Qdrant.
embedding_model = SentenceTransformer("intfloat/e5-large")
print("Loading text generation model...")
# Use a lighter model for testing
qa_pipeline = pipeline("text-generation", model="gpt2")
# === Qdrant Setup ===
print("Connecting to Qdrant...")
# NOTE(review): path= presumably selects qdrant-client's embedded on-disk
# storage under ./qdrant_data rather than a remote server — confirm.
qdrant_client = QdrantClient(path="qdrant_data")
# Name of the collection searched by retrieve_from_qdrant(); nothing in this
# file creates or populates it.
collection_name = "math_problems"
# === Guard Function ===
def is_valid_math_question(text):
    """Return True when the zero-shot classifier labels *text* as math.

    The text is accepted only if the top-ranked label is ``"math"`` and its
    score is strictly greater than 0.7.

    Args:
        text: The user-supplied question.

    Returns:
        bool: True for a confident "math" classification, False otherwise.
        Blank or non-string input is rejected without calling the model.
    """
    # An empty Gradio textbox submits ""; the zero-shot pipeline raises
    # ValueError on an empty sequence, so reject blank input up front.
    if not isinstance(text, str) or not text.strip():
        return False

    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    # Labels and scores are read at index 0, i.e. the top-ranked candidate.
    return bool(result["labels"][0] == "math" and result["scores"][0] > 0.7)
# === Retrieval ===
def retrieve_from_qdrant(query):
    """Embed *query* and return the payloads of its nearest Qdrant hits.

    Args:
        query: Text to encode with the module-level embedding model.

    Returns:
        list: The ``payload`` of each hit (at most three are requested), or an
        empty list when the search returns nothing.
    """
    print("Retrieving context from Qdrant...")
    embedding = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=embedding,
        limit=3,
    )
    print("Retrieved hits:", hits)
    if not hits:
        return []
    return [point.payload for point in hits]
# === Web Search ===
def web_search_tavily(query):
    """Ask the Tavily search API for a direct answer to *query*.

    Args:
        query: The question to search for.

    Returns:
        str: Tavily's synthesized answer, or the fixed message
        ``"No answer found from Tavily."`` when the query is blank, the
        request fails, or the response carries no answer.
    """
    import os  # local import so this block does not depend on file-level imports

    fallback = "No answer found from Tavily."
    if not isinstance(query, str) or not query.strip():
        return fallback

    print("Calling Tavily...")
    # NOTE(security): the literal below is a credential committed to source.
    # Set TAVILY_API_KEY in the environment, then rotate and remove this
    # fallback; it is kept only so existing deployments keep working.
    api_key = os.environ.get(
        "TAVILY_API_KEY", "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"
    )
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Without this flag the response has no "answer" field and the
                # fallback message would be returned for every query.
                "include_answer": True,
            },
            timeout=30,  # never hang the UI on a stalled connection
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers a body that is not valid JSON.
        print("Tavily request failed:", exc)
        return fallback

    answer = payload.get("answer") if isinstance(payload, dict) else None
    # "answer" may be present but null/empty; treat that as no answer.
    return answer or fallback
# === DSPy Signature ===
# Declares the fields of a math-answering step: a question plus retrieved
# context in, an answer out. No other code in the visible file references
# this class.
class MathAnswer(dspy.Signature):
    # NOTE: documented with comments rather than a docstring on purpose —
    # DSPy reads a signature's docstring as its prompt instructions, so adding
    # one would change behavior.

    # Inputs
    question = dspy.InputField()
    retrieved_context = dspy.InputField()

    # Output
    answer = dspy.OutputField()
# === DSPy Programs ===
class MathRetrievalQA(dspy.Program):
    """Answer a question from solutions retrieved out of Qdrant."""

    def forward(self, question):
        """Generate an answer grounded in retrieved solutions.

        Args:
            question: The user's math question.

        Returns:
            dspy.Prediction: ``answer`` holds the generated text and
            ``retrieved_context`` the newline-joined solutions used as
            context. Both are empty strings when nothing usable was
            retrieved, which callers treat as "no answer".
        """
        print("Inside MathRetrievalQA...")
        context_items = retrieve_from_qdrant(question)
        # A hit's payload may be None or lack a "solution" key; skip those
        # instead of raising TypeError on the membership test.
        context = "\n".join(
            str(item["solution"])
            for item in context_items
            if isinstance(item, dict) and "solution" in item
        )
        print("Context for generation:", context)
        if not context:
            return dspy.Prediction(answer="", retrieved_context="")

        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        print("Generating answer...")
        # return_full_text=False keeps only the continuation; by default the
        # pipeline echoes the whole prompt back inside "generated_text".
        generated = qa_pipeline(
            prompt, max_new_tokens=100, return_full_text=False
        )
        answer = generated[0]["generated_text"].strip()
        print("Generated answer:", answer)
        # dspy.Prediction is DSPy's result container (attribute access on
        # keyword fields); the original dspy.Output is not a DSPy name.
        return dspy.Prediction(answer=answer, retrieved_context=context)
class WebFallbackQA(dspy.Program):
    """Answer a question through the Tavily web search."""

    def forward(self, question):
        """Return Tavily's answer for *question*.

        Args:
            question: The user's math question.

        Returns:
            dspy.Prediction: ``answer`` is the string returned by
            ``web_search_tavily``; ``retrieved_context`` is the literal
            ``"Tavily"`` to mark the source.
        """
        print("Fallback to Tavily...")
        answer = web_search_tavily(question)
        # dspy.Prediction is DSPy's result container; the original
        # dspy.Output is not a DSPy name.
        return dspy.Prediction(answer=answer, retrieved_context="Tavily")
class MathRouter(dspy.Program):
    """Gate a question, try retrieval-based QA, then fall back to web search."""

    def forward(self, question):
        """Route *question* to the appropriate answering program.

        Args:
            question: The raw text entered by the user.

        Returns:
            An object with ``answer`` and ``retrieved_context`` attributes:
            a rejection message when the text is not classified as math, the
            retrieval-based result when its answer is non-empty, otherwise
            the web-search result.
        """
        print("Routing question:", question)
        if not is_valid_math_question(question):
            # \u274c is the cross-mark emoji; written as an escape because the
            # literal character was corrupted by a previous re-encoding.
            return dspy.Prediction(
                answer="\u274c Only math questions are accepted. Please rephrase.",
                retrieved_context="",
            )
        result = MathRetrievalQA().forward(question)
        # An empty answer means retrieval found no usable context.
        return result if result.answer else WebFallbackQA().forward(question)


# Single shared router instance used by the Gradio callbacks.
router = MathRouter()
# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer):
    """Append one feedback record to ``feedback.json`` as a line of JSON.

    Args:
        question: The question the user asked.
        answer: The answer the model produced (stored as ``model_answer``).
        feedback: The user's rating.
        correct_answer: The user-supplied correction, if any.

    Returns:
        None. The file is opened in append mode, so it accumulates one JSON
        object per line rather than forming a single JSON document.
    """
    record = dict(
        question=question,
        model_answer=answer,
        feedback=feedback,
        correct_answer=correct_answer,
        timestamp=str(datetime.now()),
    )
    print("Storing feedback:", record)
    with open("feedback.json", "a") as log_file:
        # print() supplies the trailing newline that separates records.
        print(json.dumps(record), file=log_file)
# === Gradio Functions ===
def ask_question(question):
    """Gradio callback: route *question* and return values for three outputs.

    Args:
        question: Text from the question textbox.

    Returns:
        tuple: ``(answer, question, answer)`` — the answer appears twice
        because the click handler feeds one visible and two hidden components.
    """
    print("ask_question() called with:", question)
    prediction = router.forward(question)
    print("Result:", prediction)
    answer_text = prediction.answer
    return answer_text, question, answer_text
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Gradio callback: persist one feedback entry and return a status line.

    Args:
        question: The original question.
        model_answer: The answer the model gave.
        feedback: The user's rating from the radio control.
        correct_answer: Optional correction typed by the user.

    Returns:
        str: Confirmation message shown in the status textbox.
    """
    store_feedback(question, model_answer, feedback, correct_answer)
    # \u2705 is the check-mark emoji. The literal character had been corrupted
    # into a stray byte plus a line break, which split this string literal
    # across two lines; the escape keeps it encoding-proof.
    return "\u2705 Feedback received. Thank you!"
# === Gradio UI ===
# Two tabs: one to ask a question, one to submit feedback on an answer.
# NOTE(review): the nesting below is reconstructed (the original indentation
# was lost); in particular, which components sit inside gr.Row() is assumed.
# Emoji are written as escapes because the literal characters had been
# corrupted by a previous re-encoding — which also made both radio choices
# the same string.
with gr.Blocks() as demo:
    gr.Markdown("## \U0001F9EE Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
            answer_output = gr.Markdown(label="Answer")
        # Hidden holders for the last question/answer pair. They receive
        # values from ask_question but nothing in this file reads them back;
        # the feedback tab's fields are filled in by hand.
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn = gr.Button("Get Answer")
        submit_btn.click(
            fn=ask_question,
            inputs=[question_input],
            outputs=[answer_output, hidden_q, hidden_a],
        )

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        # Thumbs-up / thumbs-down; the two choices must be distinct strings.
        fb_like = gr.Radio(["\U0001F44D", "\U0001F44E"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)
        fb_submit_btn.click(
            fn=submit_feedback,
            inputs=[fb_question, fb_answer, fb_like, fb_correct],
            outputs=[fb_status],
        )

demo.launch(share=True, debug=True)
|