Spaces:
Sleeping
Sleeping
File size: 4,233 Bytes
5775448 9ec24d8 b7b20e2 af77c21 5eea801 af77c21 b7b20e2 9ec24d8 b7b20e2 9ec24d8 b7b20e2 9ec24d8 5eea801 af77c21 b7b20e2 9ec24d8 b7b20e2 5eea801 b7b20e2 5eea801 b7b20e2 5eea801 b7b20e2 5eea801 b7b20e2 02f7269 b7b20e2 5eea801 af77c21 5eea801 af77c21 5eea801 af77c21 5eea801 af77c21 5eea801 b7b20e2 5eea801 b7b20e2 5eea801 b7b20e2 5eea801 af77c21 63e04ea 5eea801 b7b20e2 5eea801 b7b20e2 5eea801 a6a2ff2 af77c21 a6a2ff2 5eea801 a6a2ff2 af77c21 a6a2ff2 5eea801 af77c21 a6a2ff2 5eea801 b7b20e2 5eea801 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import gradio as gr
import torch
import requests
import json
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from datetime import datetime
# === Load Models ===
# Hugging Face model identifiers, grouped so they are easy to find and swap.
ZERO_SHOT_MODEL_ID = "facebook/bart-large-mnli"
EMBEDDING_MODEL_ID = "intfloat/e5-large"
GENERATOR_MODEL_ID = "microsoft/phi-2"

# Everything below runs once at import time; each step announces itself first
# because the downloads/loads can take a while.
print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model=ZERO_SHOT_MODEL_ID)

print("Loading embedding model...")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)

print("Loading step-by-step generator...")
tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL_ID)

# === Qdrant Setup ===
# `path=` points the client at an on-disk store in the working directory.
print("Connecting to Qdrant...")
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"
# === Guard Function ===
def is_valid_math_question(text):
    """Return True when the zero-shot classifier is confident *text* is math.

    Args:
        text: The user's question as typed into the UI.

    Returns:
        True if the top-ranked label is "math" with a score above 0.7;
        False otherwise, including for empty or whitespace-only input.
    """
    # The UI allows submitting an empty textbox, and the zero-shot pipeline
    # rejects an empty sequence, so blank input is refused up front.
    if not text or not text.strip():
        return False

    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    # Index 0 holds the top-ranked label and its score.
    return result["labels"][0] == "math" and result["scores"][0] > 0.7
# === Retrieval from Qdrant ===
def retrieve_from_qdrant(query):
    """Return the payloads of the stored problems nearest to *query*.

    Args:
        query: Question text to embed and search for.

    Returns:
        A list of at most three payload dicts; empty when nothing matched.
    """
    query_vector = embedding_model.encode(query).tolist()
    # `query_points` receives the vector through `query=` (`query_vector=` is
    # the keyword of the older `search` method) and returns a response object
    # whose matches are listed in `.points`.
    response = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=3,
    )
    # A point stored without a payload comes back with `payload=None`; drop
    # those so callers can treat every item as a dict.
    return [point.payload for point in response.points if point.payload is not None]
# === Web Search Fallback ===
def web_search_tavily(query):
    """Ask the Tavily search API for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable so
    the secret does not live in the source file.
    NOTE: any key that was previously committed should be rotated.

    Args:
        query: Question text to search for.

    Returns:
        Tavily's answer string, or a short human-readable message when the
        key is missing, the request fails, or no answer is returned.
    """
    import os  # local import: only this function needs it

    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        return "Web search is not configured (TAVILY_API_KEY is not set)."

    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily only fills in "answer" when explicitly asked to.
                "include_answer": True,
            },
            # Without a timeout a stalled connection would hang the UI request.
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError):
        # Network failure, non-2xx status, or a body that is not valid JSON.
        return "Web search request failed. Please try again later."

    # "answer" can be present but null, which .get()'s default would not cover.
    return payload.get("answer") or "No answer found from Tavily."
# === Generator ===
def generate_step_by_step_answer(question, context):
    """Generate a step-by-step solution for *question* using *context*.

    Args:
        question: The math question to answer.
        context: Reference text (retrieved solutions) placed in the prompt.

    Returns:
        The newly generated answer text, without the prompt echoed back.
    """
    prompt = f"Answer the following math question step-by-step:\nQuestion: {question}\nContext: {context}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        # Reuse EOS as the padding id so generate() has a pad token to use.
        pad_token_id=tokenizer.eos_token_id,
    )
    # A causal LM returns prompt + continuation in one sequence; keep only the
    # tokens after the prompt so the user is not shown the prompt and the
    # retrieved context a second time.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# === Router ===
def router(question):
    """Route *question* to retrieval-augmented generation or to web search.

    Args:
        question: Raw question text from the UI.

    Returns:
        A ``(answer, source)`` tuple:
        - the rejection message and ``""`` when the guard refuses the input;
        - the generated answer and the joined retrieved solutions when the
          knowledge base returned at least one usable solution;
        - the web-search answer and the literal ``"Tavily Search"`` otherwise.
    """
    if not is_valid_math_question(question):
        return "❌ Only math questions are accepted. Please rephrase.", ""

    retrieved = retrieve_from_qdrant(question)
    # Ignore hits with no payload or no usable solution, and coerce to str so
    # a non-string payload value cannot break the join.
    solutions = [
        str(item["solution"])
        for item in retrieved
        if item and item.get("solution")
    ]
    context = "\n".join(solutions)

    if context:
        return generate_step_by_step_answer(question, context), context
    # Nothing useful in the knowledge base: fall back to web search.
    return web_search_tavily(question), "Tavily Search"
# === Feedback Storage ===
def store_feedback(question, answer, correct_answer):
    """Append one feedback record to ``feedback.json`` as a single JSON line.

    Args:
        question: The question that was asked.
        answer: The answer the model produced (stored as ``model_answer``).
        correct_answer: The correction supplied by the user.
    """
    record = dict(
        question=question,
        model_answer=answer,
        correct_answer=correct_answer,
        # Local wall-clock time, in datetime's default string form.
        timestamp=f"{datetime.now()}",
    )
    # Append mode keeps earlier feedback; one JSON object per line.
    with open("feedback.json", "a") as log_file:
        print(json.dumps(record), file=log_file)
# === Gradio Functions ===
def ask_question(question):
    """Answer *question* and hand back the exchange for later feedback.

    Args:
        question: Text from the question textbox.

    Returns:
        ``(answer, question, answer)`` — the answer to display, followed by
        the question/answer pair that the click handler stores in the two
        hidden textboxes.
    """
    # The second element of router's result (context/source) is not shown.
    reply, _source = router(question)
    return (reply, question, reply)
def submit_feedback(question, model_answer, correct_answer):
    """Persist user feedback for the last answered question.

    Args:
        question: The question the feedback refers to.
        model_answer: The answer the model gave for that question.
        correct_answer: The user's correction; may be empty.

    Returns:
        A status message for the UI.
    """
    # An empty question means nothing has been asked yet, so there is nothing
    # meaningful to record.
    if not question:
        return "Please ask a question before submitting feedback."
    store_feedback(question, model_answer, correct_answer)
    return "✅ Feedback received. Thank you!"
# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with Retrieval + Feedback")

    # NOTE(review): the source lost its indentation, so which widgets sit
    # inside this Row is reconstructed — confirm against the intended layout.
    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
        submit_btn = gr.Button("Get Answer")

    answer_output = gr.Markdown(label="Answer")
    # Invisible fields that carry the last question/answer pair over to the
    # feedback handler.
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)
    submit_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, hidden_q, hidden_a],
    )

    gr.Markdown("### Submit Feedback")
    fb_correct = gr.Textbox(label="Correct Answer (optional)")
    fb_submit = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status", interactive=False)
    fb_submit.click(
        fn=submit_feedback,
        inputs=[hidden_q, hidden_a, fb_correct],
        outputs=[fb_status],
    )

# Starts the web server as soon as the module is executed.
demo.launch(share=True)
|