Spaces:
Sleeping
Sleeping
File size: 7,110 Bytes
5775448 9ec24d8 b7b20e2 d16f9ab af77c21 5eea801 d16f9ab af77c21 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 d16f9ab af77c21 b7b20e2 9ec24d8 b7b20e2 5eea801 b7b20e2 d16f9ab b7b20e2 d16f9ab b7b20e2 d16f9ab b7b20e2 d16f9ab b7b20e2 d16f9ab b7b20e2 d16f9ab 02f7269 b7b20e2 d16f9ab b7b20e2 d16f9ab b7b20e2 d16f9ab b7b20e2 d16f9ab b7b20e2 d16f9ab 5eea801 b7b20e2 d16f9ab 63e04ea d16f9ab b7b20e2 5eea801 b7b20e2 d16f9ab af77c21 d16f9ab a6a2ff2 d16f9ab 87cc698 b7b20e2 87cc698 d16f9ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import json
import os
from datetime import datetime

import dspy
import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# === Load Models ===
# Hugging Face checkpoints used by the guard and the retriever, kept in one
# place so they are easy to swap.
_ZERO_SHOT_CHECKPOINT = "facebook/bart-large-mnli"
_EMBEDDING_CHECKPOINT = "intfloat/e5-large"

print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model=_ZERO_SHOT_CHECKPOINT)

print("Loading embedding model...")
embedding_model = SentenceTransformer(_EMBEDDING_CHECKPOINT)

# NOTE(review): no text-generation model is actually loaded here; this log
# line is kept unchanged so existing logs stay comparable.
print("Loading text generation model...")

# === Qdrant Setup ===
print("Connecting to Qdrant...")
# Local (embedded) Qdrant client persisted under the ./qdrant_data directory.
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"
# === Guard Function ===
def is_valid_math_question(text, threshold=0.7):
    """Return True if the zero-shot classifier judges *text* to be a math question.

    The module-level ``classifier`` pipeline scores *text* against the labels
    "math" and "not math". The text is accepted only when "math" is the
    top-ranked label and its score is strictly greater than *threshold*.

    Args:
        text: The user's question.
        threshold: Minimum (exclusive) score for the "math" label. Defaults
            to 0.7, the value that was previously hard-coded.

    Returns:
        bool: False for None/empty/whitespace-only input (the classifier is
        not called in that case); otherwise the classifier-based decision.
    """
    # Blank input can never be a math question; skip the expensive model call.
    if not text or not text.strip():
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    # The zero-shot pipeline returns labels sorted by likelihood, so index 0
    # is the top-ranked label and scores[0] is its score.
    return result["labels"][0] == "math" and result["scores"][0] > threshold
# === Retrieval ===
def retrieve_from_qdrant(query, limit=3):
    """Return the payloads of the Qdrant points nearest to *query*.

    The query is embedded with the module-level ``embedding_model`` and
    searched in ``collection_name`` via ``qdrant_client``.

    Args:
        query: Text to embed and search for.
        limit: Maximum number of hits to request. Defaults to 3, the value
            that was previously hard-coded.

    Returns:
        list[dict]: Payloads of the hits, in the order Qdrant returned them.
        Hits with a missing or empty payload are skipped. Returns an empty
        list when nothing matches.
    """
    print("Retrieving context from Qdrant...")
    # NOTE(review): intfloat/e5-* models are trained with "query: " /
    # "passage: " prefixes. Whether one belongs here depends on how the
    # collection was indexed. TODO confirm against the ingestion code.
    query_vector = embedding_model.encode(query).tolist()
    # NOTE(review): newer qdrant-client releases deprecate `search` in favour
    # of `query_points`; kept as-is for compatibility with the pinned client.
    hits = qdrant_client.search(
        collection_name=collection_name, query_vector=query_vector, limit=limit
    )
    print("Retrieved hits:", hits)
    # `payload` is optional on returned points. Callers index into it as a
    # dict, so drop missing/empty payloads instead of passing None along.
    return [hit.payload for hit in hits or [] if hit.payload]
# === Web Search ===
_TAVILY_SEARCH_URL = "https://api.tavily.com/search"
# Sentinel text returned whenever no usable answer is available. The web
# fallback checks for the "No answer found" substring.
_TAVILY_NO_ANSWER = "No answer found from Tavily."


def web_search_tavily(query, timeout=30):
    """Ask Tavily's search API for a synthesized answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable
    (e.g. a Hugging Face Space secret) instead of being embedded in source.

    Args:
        query: The search query (the user's question).
        timeout: Seconds to wait for the HTTP request. Defaults to 30.

    Returns:
        str: Tavily's ``answer`` text. Returns ``"No answer found from Tavily."``
        when no key is configured, the request or JSON decoding fails, the
        server returns an error status, or the response has no answer.
    """
    print("Calling Tavily...")
    # SECURITY: a key used to be hard-coded here. Treat that key as leaked
    # and revoke it; supply a fresh one through the environment.
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        print("TAVILY_API_KEY is not set; skipping web search.")
        return _TAVILY_NO_ANSWER
    try:
        response = requests.post(
            _TAVILY_SEARCH_URL,
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily only includes a synthesized "answer" when asked to.
                "include_answer": True,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as err:
        # Network/HTTP/JSON failures degrade to the "no answer" path rather
        # than crashing the Gradio request handler.
        print("Tavily request failed:", err)
        return _TAVILY_NO_ANSWER
    # "answer" may be absent or explicitly null; normalise both to the sentinel.
    answer = data.get("answer") if isinstance(data, dict) else None
    return answer or _TAVILY_NO_ANSWER
# === DSPy Signature ===
# Declarative DSPy signature for the QA task: inputs are the user's question
# and retrieved context, output is the answer.
# Documentation is kept in comments on purpose: in DSPy a class docstring
# becomes the signature's task instructions, so adding one would change the
# prompt.
# NOTE(review): this signature is not referenced by the visible code;
# MathRetrievalQA builds its prompt string by hand.
class MathAnswer(dspy.Signature):
    question = dspy.InputField()  # the user's math question
    retrieved_context = dspy.InputField()  # context text assembled from retrieval
    answer = dspy.OutputField()  # the generated answer
# === DSPy Programs with Output Guard ===
class MathRetrievalQA(dspy.Program):
    """Answer a question from solutions retrieved out of Qdrant.

    ``forward`` returns a dict with keys ``"answer"`` and
    ``"retrieved_context"``. An empty ``"answer"`` means retrieval or the
    output guard failed; MathRouter uses that to fall back to web search.
    """

    # Phrases that mark a non-answer; matched case-insensitively. Both ASCII
    # and typographic apostrophes are covered.
    _REFUSAL_MARKERS = ("i don't know", "i don\u2019t know")
    # Answers shorter than this (after stripping) are treated as unusable.
    _MIN_ANSWER_LENGTH = 10

    def forward(self, question):
        print("Inside MathRetrievalQA...")
        context = self._build_context(retrieve_from_qdrant(question))
        print("Context for generation:", context)
        if not context:
            return {"answer": "", "retrieved_context": ""}
        # === Replace below with real model call when ready ===
        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        print("Prompt for generation:", prompt)
        # TEMP answer (replace with real generated output)
        generated_answer = "This is a placeholder answer based on the context."  # Simulated generation
        print("Generated answer:", generated_answer)
        # === Output Guard ===
        if not self._passes_output_guard(generated_answer):
            return {"answer": "", "retrieved_context": context}
        return {"answer": generated_answer.strip(), "retrieved_context": context}

    @staticmethod
    def _build_context(items):
        """Join the ``"solution"`` value of each retrieved payload, one per line.

        Non-dict items and missing/None solutions are skipped. Values are
        coerced to ``str`` so a non-string solution cannot break the join.
        """
        solutions = [
            str(item["solution"])
            for item in items
            if isinstance(item, dict) and item.get("solution") is not None
        ]
        return "\n".join(solutions)

    @classmethod
    def _passes_output_guard(cls, text):
        """Return True if *text* looks like a usable answer (long enough, not a refusal)."""
        if not text:
            return False
        stripped = text.strip()
        if len(stripped) < cls._MIN_ANSWER_LENGTH:
            return False
        lowered = stripped.casefold()
        return not any(marker in lowered for marker in cls._REFUSAL_MARKERS)
class WebFallbackQA(dspy.Program):
    """Answer a question via Tavily web search when retrieval produced nothing.

    ``forward`` returns a dict with ``"answer"`` and ``"retrieved_context"``.
    The context is always the literal ``"Tavily"``.
    """

    def forward(self, question):
        print("Fallback to Tavily...")
        answer = web_search_tavily(question)
        # Non-strings (e.g. None), very short replies and Tavily's "no answer"
        # sentinel all get the same apology instead of crashing on .strip().
        if (
            not isinstance(answer, str)
            or len(answer.strip()) < 10
            or "No answer found" in answer
        ):
            # Literal restored: the scraped "โ" was a mis-decoded "❌".
            answer = "❌ Sorry, I couldn't find a reliable answer."
        return {"answer": answer.strip(), "retrieved_context": "Tavily"}
class MathRouter(dspy.Program):
    """Top-level pipeline: reject non-math input, try retrieval, then web search.

    ``forward`` always returns a dict with ``"answer"`` and
    ``"retrieved_context"`` keys.
    """

    def forward(self, question):
        print("Routing question:", question)
        # Blank input is rejected up front without running the classifier.
        if not question or not question.strip() or not is_valid_math_question(question):
            # Literal restored: the scraped "โ" was a mis-decoded "❌".
            return {"answer": "❌ Only math questions are accepted. Please rephrase.", "retrieved_context": ""}
        result = MathRetrievalQA().forward(question)
        if result["answer"]:
            return result
        # An empty answer from retrieval means "no usable context/answer".
        return WebFallbackQA().forward(question)
# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer, path="feedback.json"):
    """Append one feedback record as a single JSON line to *path*.

    Args:
        question: The question the user asked.
        answer: The model's answer (stored under ``"model_answer"``).
        feedback: The user's rating value.
        correct_answer: Optional user-supplied correct answer.
        path: JSON Lines file to append to. Defaults to ``feedback.json`` in
            the working directory, the previously hard-coded location.
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        # Naive local time rendered with str(), i.e. "YYYY-MM-DD HH:MM:SS.ffffff".
        "timestamp": str(datetime.now()),
    }
    print("Storing feedback:", entry)
    # json.dumps escapes non-ASCII by default, so each line is pure ASCII;
    # the explicit encoding keeps the file independent of the platform default.
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
def load_feedback_entries(path="feedback.json"):
    """Read every feedback record from the JSON Lines file at *path*.

    Blank lines are ignored. Lines that are not valid JSON, or that decode
    to something other than an object, are skipped with a printed warning.
    This way a single truncated or corrupt line does not hide all other
    entries.

    Args:
        path: File to read. Defaults to ``feedback.json``, the previously
            hard-coded location.

    Returns:
        list[dict]: Records in file order; an empty list if the file does
        not exist.
    """
    entries = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as err:
                    print(f"Skipping malformed feedback line {line_no}: {err}")
                    continue
                # Consumers index records by key, so only objects are kept.
                if isinstance(entry, dict):
                    entries.append(entry)
                else:
                    print(f"Skipping non-object feedback line {line_no}")
    except FileNotFoundError:
        # No feedback has been stored yet.
        pass
    return entries
# === Gradio Functions ===
# `router` was used by ask_question but was not defined anywhere in this file
# (every call raised NameError). Build it once here so requests share it.
router = MathRouter()


def ask_question(question):
    """Gradio handler for the "Get Answer" button.

    Args:
        question: Text from the question textbox.

    Returns:
        tuple[str, str, str]: (answer shown as Markdown, the original
        question, the answer). The last two fill the hidden question/answer
        textboxes.
    """
    print("ask_question() called with:", question)
    result = router.forward(question)
    print("Result:", result)
    answer = result["answer"]
    return answer, question, answer
def submit_feedback(question, model_answer, feedback, correct_answer):
    """Persist one feedback record and return a confirmation message.

    NOTE(review): the UI in this file wires ``feedback_submission_and_display``
    to the submit button; this handler is not referenced by the visible code.
    """
    store_feedback(question, model_answer, feedback, correct_answer)
    # The literal was split across two lines by a mis-decoded "✅" (an
    # unterminated string / SyntaxError); restored as one string.
    return "✅ Feedback received. Thank you!"
# === Gradio UI ===
# Emoji below were mojibake in the scraped source (UTF-8 decoded as
# ISO-8859-11) and have been restored. In particular the two Radio choices
# had collapsed into identical strings, and "✅" had split two string
# literals across lines.
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
        gr.Markdown("### 🧠 Answer:")
        answer_output = gr.Markdown()
        # Filled by ask_question with the last question/answer.
        # NOTE(review): nothing else in this file reads these two components;
        # they could be used to prefill the feedback tab.
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn = gr.Button("Get Answer")
        submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        fb_like = gr.Radio(["👍", "👎"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)
        feedback_display = gr.Dataframe(
            headers=["Question", "Answer", "Feedback", "Correct Answer", "Timestamp"],
            row_count=10,
            max_rows=50,
            wrap=True,
        )

        def feedback_submission_and_display(question, answer, feedback, correct_answer):
            """Store one feedback record, then return a status message and all stored rows."""
            store_feedback(question, answer, feedback, correct_answer)
            # Key order must match the Dataframe headers above. .get() keeps
            # older records that lack a field from breaking the table.
            fields = ("question", "model_answer", "feedback", "correct_answer", "timestamp")
            display_rows = [[e.get(key, "") for key in fields] for e in load_feedback_entries()]
            return "✅ Feedback received. Thank you!", display_rows

        fb_submit_btn.click(
            fn=feedback_submission_and_display,
            inputs=[fb_question, fb_answer, fb_like, fb_correct],
            outputs=[fb_status, feedback_display],
        )

demo.launch(share=True, debug=True)
|