Spaces:
Sleeping
Sleeping
import json
import os
import re
from datetime import datetime

import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# === Load Models ===
# All models are loaded once at import time and shared by the functions
# below as module-level globals.
print("Loading zero-shot classifier...")
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)

print("Loading embedding model...")
embedding_model = SentenceTransformer("intfloat/e5-large")

print("Loading WizardMath model...")
_WIZARDMATH_ID = "WizardLM/WizardMath-7B-V1.1"
tokenizer = AutoTokenizer.from_pretrained(_WIZARDMATH_ID)
# Half precision, with weights placed automatically by device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    _WIZARDMATH_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
# === Qdrant Setup ===
# Local-mode Qdrant client (qdrant-client `path=` mode) storing its data
# under ./qdrant_data; retrieval queries the "math_problems" collection.
print("Connecting to Qdrant...")
collection_name = "math_problems"
qdrant_client = QdrantClient(path="qdrant_data")
# === Guard Functions ===
def is_valid_math_question(text, threshold=0.7):
    """Return True if the zero-shot classifier labels *text* as math.

    The text is scored against the labels "math" and "not math" and is
    accepted only when "math" is the top label with a score strictly above
    *threshold*.

    Args:
        text: The user's question.
        threshold: Minimum top-label score required to accept the text.

    Returns:
        bool. Empty, whitespace-only, or non-string input is rejected
        without calling the classifier.
    """
    if not isinstance(text, str) or not text.strip():
        return False
    result = classifier(text, ["math", "not math"])
    top_label, top_score = result["labels"][0], result["scores"][0]
    return top_label == "math" and top_score > threshold
# At least one of these must appear (plain case-insensitive substring match).
_MATH_ANSWER_KEYWORDS = (
    "solve", "equation", "integral", "derivative",
    "value", "expression", "steps", "solution",
)
# Banned words must start a word: this still catches "bombs"/"killing" but no
# longer rejects innocent words such as "whatever" (hate) or "skill" (kill).
_BANNED_ANSWER_RE = re.compile(
    r"\b(?:kill|bomb|hate|politics|violence)", re.IGNORECASE
)
# Refusal phrases at the start of the answer. The group makes the optional
# leading whitespace apply to every alternative, not only the first.
_REFUSAL_PREFIX_RE = re.compile(
    r"^\s*(?:I'm just a model|Sorry, I can't|As an AI)", re.IGNORECASE
)


def output_guardrails(answer):
    """Return True if *answer* looks like an acceptable math answer.

    An answer is rejected when it:
      * is not a string or is shorter than 10 characters once stripped,
      * contains none of the math keywords,
      * contains a word starting with a banned keyword, or
      * opens with a refusal phrase ("I'm just a model", "Sorry, I can't",
        "As an AI"), optionally after leading whitespace.
    """
    if not isinstance(answer, str) or len(answer.strip()) < 10:
        return False
    lowered = answer.lower()
    if not any(keyword in lowered for keyword in _MATH_ANSWER_KEYWORDS):
        return False
    if _BANNED_ANSWER_RE.search(answer):
        return False
    if _REFUSAL_PREFIX_RE.match(answer):
        return False
    return True
# === Retrieval ===
def retrieve_from_qdrant(query, limit=3):
    """Return payloads of the stored problems nearest to *query*.

    The query is embedded with ``embedding_model`` and searched in
    ``collection_name`` via ``qdrant_client``.

    Args:
        query: The user's question text.
        limit: Maximum number of hits to request from Qdrant.

    Returns:
        A list of payloads (one per hit). Hits whose payload is missing
        (``None``) or empty are skipped, so callers can safely call
        ``.get()`` on every item.
    """
    # NOTE(review): the e5 model card recommends a "query: " prefix on
    # queries (and "passage: " on indexed texts). Confirm how the collection
    # was built before adding one here.
    query_vector = embedding_model.encode(query).tolist()
    # NOTE(review): newer qdrant-client releases deprecate `search` in favour
    # of `query_points`; revisit when upgrading the client.
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    return [hit.payload for hit in hits or () if hit.payload]
# === Web Search ===
def web_search_tavily(query, timeout=30):
    """Ask the Tavily search API for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable
    (e.g. a Spaces secret).

    Args:
        query: Search query text.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        Tavily's generated answer string, or
        ``"No answer found from Tavily."`` if no key is configured, the
        request or JSON decoding fails, or the response has no answer.
    """
    fallback = "No answer found from Tavily."
    # The key was previously hard-coded in source. That key should be
    # considered leaked and revoked; it is now supplied via the environment.
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        print("TAVILY_API_KEY is not set; skipping web search.")
        return fallback
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Without this flag Tavily returns only search results and
                # no generated "answer".
                "include_answer": True,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # A failed search degrades to "no answer" instead of crashing the UI.
        print(f"Tavily search failed: {exc}")
        return fallback
    answer = payload.get("answer") if isinstance(payload, dict) else None
    return answer or fallback
# === Answer Generation ===
def generate_step_by_step_answer(question, context="", max_new_tokens=256):
    """Generate a step-by-step solution to *question* with WizardMath.

    Args:
        question: The math question text.
        context: Optional reference text (e.g. retrieved solutions), inserted
            under a "### Context:" heading when non-empty.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The text generated after the prompt, stripped of surrounding
        whitespace. Sampling is enabled (temperature 0.7, top_p 0.95), so
        repeated calls can return different text.
    """
    prompt = f"### Question:\n{question}\n"
    if context:
        prompt += f"### Context:\n{context}\n"
    prompt += "### Let's solve it step by step:\n"
    # Follow wherever device_map="auto" placed the model rather than
    # assuming a CUDA device exists.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The returned sequence starts with the prompt tokens. Decode only the new
    # tokens; splitting the full text on the prompt marker would drop output
    # if the model repeated the marker.
    prompt_length = inputs["input_ids"].shape[-1]
    generated = outputs[0][prompt_length:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
# === Router ===
def router(question):
    """Answer *question*, preferring retrieved context over web search.

    Flow:
      1. Reject input that ``is_valid_math_question`` does not accept.
      2. Retrieve similar problems from Qdrant. If any carry a non-empty
         "solution", generate an answer with them as context and return it
         if it passes ``output_guardrails``.
      3. Otherwise fall back to a Tavily web search, also gated by
         ``output_guardrails``.

    Returns:
        The answer text, or a user-facing rejection/failure message.
    """
    if not is_valid_math_question(question):
        return "❌ Only math questions are accepted. Please rephrase."
    context_items = retrieve_from_qdrant(question)
    # Skip missing payloads and empty solutions so the model is never
    # prompted with a whitespace-only context block.
    context = "\n".join(
        str(item["solution"])
        for item in context_items
        if item and item.get("solution")
    )
    if context:
        answer = generate_step_by_step_answer(question, context)
        if output_guardrails(answer):
            return answer
    answer = web_search_tavily(question)
    return answer if output_guardrails(answer) else "⚠️ No valid math answer found."
# === Feedback Storage ===
def store_feedback(
    question, answer, feedback, correct_answer, path="feedback.json"
):
    """Append one feedback record to *path* as a single JSON line.

    Despite the ``.json`` default name, the file is JSON Lines: each call
    appends one JSON object followed by a newline.

    Args:
        question: The question the user asked.
        answer: The answer that was shown (stored as "model_answer").
        feedback: The user's rating.
        correct_answer: A user-supplied corrected answer ("" if none).
        path: File to append to. Defaults to ``feedback.json`` in the current
            working directory.
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        # Naive local time, formatted by str(datetime).
        "timestamp": str(datetime.now()),
    }
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
# === Gradio UI ===
def ask_question(question):
    """Answer *question* for the UI.

    Returns a 3-tuple ``(answer, question, answer)``: the answer for the
    visible output, then the question/answer pair for the hidden fields
    that the feedback handler reads.
    """
    answer_text = router(question)
    return (answer_text, question, answer_text)
def submit_feedback(question, model_answer, feedback):
    """Store the user's rating for the last answered question.

    Args:
        question: The last question (from the hidden field; "" if none yet).
        model_answer: The answer that was shown.
        feedback: The selected rating (``None``/"" if nothing selected).

    Returns:
        A status message for the UI. Nothing is stored when no question has
        been answered yet or no rating was selected, because such rows carry
        no usable signal.
    """
    if not question:
        return "⚠️ Ask a question before submitting feedback."
    if not feedback:
        return "⚠️ Please select a rating before submitting feedback."
    store_feedback(question, model_answer, feedback, "")
    return "✅ Feedback received. Thank you!"
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Tutor with AI Guardrails + Feedback")
    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
        submit_btn = gr.Button("Get Answer")
    answer_output = gr.Markdown()
    # Hidden fields remember the last question/answer pair so the feedback
    # button can store it alongside the rating.
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)
    submit_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, hidden_q, hidden_a],
    )

    gr.Markdown("### 📝 Feedback")
    # The two choices must be distinct strings, otherwise "helpful" and
    # "not helpful" cannot be told apart in the stored feedback.
    fb_like = gr.Radio(["👍", "👎"], label="Was this answer helpful?")
    fb_submit_btn = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status", interactive=False)
    fb_submit_btn.click(
        fn=submit_feedback,
        inputs=[hidden_q, hidden_a, fb_like],
        outputs=[fb_status],
    )

# Launch only when run as a script, so importing this module (e.g. by
# Gradio's reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)