# Hugging Face Space (status when this page was captured: Sleeping)
import json
import os
from datetime import datetime

import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# === Load Models ===
# All three models are loaded once at import time and shared by the request
# handlers below through module-level names.

# Zero-shot classifier used by is_valid_math_question to gate incoming text.
print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# Sentence encoder used by retrieve_from_qdrant to embed the query.
print("Loading embedding model...")
embedding_model = SentenceTransformer("intfloat/e5-large")

# Causal language model (and its tokenizer) used by
# generate_step_by_step_answer to write the final answer.
print("Loading step-by-step generator...")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
# === Qdrant Setup ===
print("Connecting to Qdrant...")
# The client is pointed at the local "qdrant_data" path.
# NOTE(review): nothing in this file creates or fills the collection, so it is
# presumably populated elsewhere — confirm it exists before deployment.
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"
# === Guard Function ===
def is_valid_math_question(text):
    """Decide whether ``text`` should be treated as a math question.

    The text is scored by the module-level zero-shot ``classifier`` against
    the labels "math" and "not math". It is accepted only when the first
    label returned is "math" and its score is above 0.7.

    Args:
        text: The user's question.

    Returns:
        True when the text is accepted as a math question, otherwise False.
        Empty, whitespace-only, or missing text is rejected without calling
        the classifier.
    """
    # There is nothing meaningful to classify in a blank submission.
    if not text or not text.strip():
        return False

    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    # NOTE(review): this relies on the pipeline returning labels ordered by
    # descending score, so index 0 is the best match — confirm against the
    # transformers documentation.
    return result["labels"][0] == "math" and result["scores"][0] > 0.7
# === Retrieval from Qdrant ===
def retrieve_from_qdrant(query):
    """Return the payloads of the stored points closest to ``query``.

    The query is embedded with the module-level ``embedding_model`` and the
    top three matches are requested from ``collection_name``.

    Args:
        query: The user's question as plain text.

    Returns:
        A list of payload dicts, one per hit that has a payload. The list is
        empty when there are no hits.
    """
    query_vector = embedding_model.encode(query).tolist()
    # query_points takes the vector through `query`; `query_vector` belongs
    # to the older search() API and is rejected here.
    response = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=3,
    )
    # The hits live on `.points` of the response object. Points without a
    # payload are skipped so callers can safely do key lookups on each item.
    return [point.payload for point in response.points if point.payload]
# === Web Search Fallback ===
def web_search_tavily(query):
    """Ask the Tavily search API for an answer to ``query``.

    The API key is read from the ``TAVILY_API_KEY`` environment variable
    instead of being embedded in the source.
    NOTE(review): the key that was previously committed in this function
    should be treated as leaked and rotated.

    Args:
        query: The question to search for.

    Returns:
        The answer text from Tavily, or a short explanatory message when the
        key is not configured, the request fails, or no answer is returned.
    """
    no_answer = "No answer found from Tavily."

    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        return "Tavily API key is not configured."

    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily only includes an "answer" field when asked to.
                "include_answer": True,
            },
            # Without a timeout a stalled connection would hang the UI.
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        # Best-effort fallback: report the failure and show the default text.
        print(f"Tavily search failed: {exc}")
        return no_answer

    if not isinstance(data, dict):
        return no_answer
    # `or` also covers an explicit null/empty "answer" in the response.
    return data.get("answer") or no_answer
# === Generator ===
def generate_step_by_step_answer(question, context):
    """Generate a step-by-step answer with the module-level language model.

    Args:
        question: The math question to answer.
        context: Reference text placed in the prompt after the question.

    Returns:
        The text the model generated after the prompt, with special tokens
        removed and surrounding whitespace stripped.
    """
    prompt = f"Answer the following math question step-by-step:\nQuestion: {question}\nContext: {context}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,  # sampling, so repeated calls can give different text
        pad_token_id=tokenizer.eos_token_id,
    )
    # The returned sequence starts with the prompt tokens. Dropping them keeps
    # the prompt and the full context from being echoed back as the answer.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# === Router ===
def router(question):
    """Route a question to retrieval-backed generation or to web search.

    Args:
        question: The user's question.

    Returns:
        A 2-tuple ``(answer, source)``:
        - a rejection message and ``""`` when the guard rejects the question;
        - the generated answer and the retrieved context when at least one
          retrieved payload has a "solution";
        - the Tavily result and ``"Tavily Search"`` otherwise.
    """
    if not is_valid_math_question(question):
        # NOTE(review): the leading character looks like a mis-decoded emoji;
        # it is kept as-is because the intended glyph cannot be confirmed.
        return "โ Only math questions are accepted. Please rephrase.", ""

    retrieved = retrieve_from_qdrant(question)
    # Only payloads that carry a "solution" contribute to the context.
    solutions = [item["solution"] for item in retrieved if "solution" in item]
    context = "\n".join(solutions)

    if not context:
        # Nothing usable in the vector store: fall back to web search.
        return web_search_tavily(question), "Tavily Search"

    return generate_step_by_step_answer(question, context), context
# === Feedback Storage ===
def store_feedback(question, answer, correct_answer):
    """Append one feedback record to ``feedback.json``.

    Each call writes a single JSON object on its own line, so the file is a
    sequence of JSON lines and not one JSON document.

    Args:
        question: The question that was asked.
        answer: The answer the model produced.
        correct_answer: The correction supplied by the user.
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now()),  # local time, no timezone info
    }
    # Append mode keeps earlier records. The encoding is explicit so the
    # file does not depend on the platform default.
    with open("feedback.json", "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
# === Gradio Functions ===
def ask_question(question):
    """Answer a question for the "Get Answer" button.

    Args:
        question: Text from the question input box.

    Returns:
        A 3-tuple ``(answer, question, answer)``. The UI wiring sends these
        to the visible answer area and to the two hidden textboxes that are
        later passed to the feedback handler.
    """
    # The second value from router (context / source) is not shown in the UI.
    answer, _context = router(question)
    return answer, question, answer
def submit_feedback(question, model_answer, correct_answer):
    """Store user feedback for the "Submit Feedback" button.

    Args:
        question: The question that was asked.
        model_answer: The answer that was shown to the user.
        correct_answer: The correction typed by the user.

    Returns:
        A confirmation message for the status box.
    """
    store_feedback(question, model_answer, correct_answer)
    # NOTE(review): the leading character looks like a mis-decoded emoji;
    # it is kept as-is because the intended glyph cannot be confirmed.
    return "โ Feedback received. Thank you!"
# === Gradio UI ===
# NOTE(review): the indentation of this section was lost, so the nesting below
# is reconstructed. It assumes the Row holds only the question box and the
# button, and that launch() runs after the Blocks context closes — confirm.
# NOTE(review): the characters after "##" and "###" in the Markdown headings
# look like mis-decoded emoji; they are kept as-is.
with gr.Blocks() as demo:
    gr.Markdown("## ๐งฎ Math Question Answering with Retrieval + Feedback")

    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
        submit_btn = gr.Button("Get Answer")

    answer_output = gr.Markdown(label="Answer")
    # Hidden fields carry the last question and answer over to the feedback
    # handler without displaying them a second time.
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)
    submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])

    gr.Markdown("### ๐ Submit Feedback")
    fb_correct = gr.Textbox(label="Correct Answer (optional)")
    fb_submit = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status", interactive=False)
    fb_submit.click(
        fn=submit_feedback,
        inputs=[hidden_q, hidden_a, fb_correct],
        outputs=[fb_status],
    )

demo.launch(share=True)