# Math tutor backend: input/output guardrails, Qdrant retrieval, Tavily web
# search fallback, WizardMath answer generation, and feedback logging.

import json
import os
import re
from datetime import datetime

import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# === Tunables ===
MATH_CONFIDENCE_THRESHOLD = 0.7
MIN_ANSWER_LENGTH = 10
MATH_KEYWORDS = (
    "solve",
    "equation",
    "integral",
    "derivative",
    "value",
    "expression",
    "steps",
    "solution",
)
# Anchored on a leading word boundary so that e.g. "whatever" (contains "hate")
# or "skill" (contains "kill") no longer trip the filter, while inflections
# such as "killing" or "bombs" still do.
_BANNED_WORDS_RE = re.compile(
    r"\b(?:kill|bomb|hate|politics|violence)\w*", re.IGNORECASE
)
# The alternation is grouped so the optional leading whitespace applies to
# every refusal phrase, not only the first one.
_REFUSAL_RE = re.compile(
    r"\s*(?:I'm just a model|Sorry, I can't|As an AI)", re.IGNORECASE
)
ANSWER_MARKER = "### Let's solve it step by step:"
TAVILY_SEARCH_URL = "https://api.tavily.com/search"
TAVILY_TIMEOUT_SECONDS = 30
TAVILY_NO_ANSWER = "No answer found from Tavily."
FEEDBACK_PATH = "feedback.json"

# === Load Models ===
print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

print("Loading embedding model...")
embedding_model = SentenceTransformer("intfloat/e5-large")

print("Loading WizardMath model...")
tokenizer = AutoTokenizer.from_pretrained("WizardLM/WizardMath-7B-V1.1")
model = AutoModelForCausalLM.from_pretrained(
    "WizardLM/WizardMath-7B-V1.1",
    torch_dtype=torch.float16,
    device_map="auto",
)

# === Qdrant Setup ===
print("Connecting to Qdrant...")
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"


# === Guard Functions ===
def is_valid_math_question(text):
    """Return True when the zero-shot classifier labels *text* as math.

    The top label must be "math" with a score above
    ``MATH_CONFIDENCE_THRESHOLD``. Empty or whitespace-only input is rejected
    up front instead of being sent to the classifier.
    """
    if not text or not text.strip():
        return False
    result = classifier(text, ["math", "not math"])
    return (
        result["labels"][0] == "math"
        and result["scores"][0] > MATH_CONFIDENCE_THRESHOLD
    )


def output_guardrails(answer):
    """Return True when *answer* looks like an acceptable math answer.

    An answer is rejected when it is empty or shorter than
    ``MIN_ANSWER_LENGTH`` characters after stripping, contains none of
    ``MATH_KEYWORDS``, contains a banned word, or opens with a refusal phrase.
    """
    if not answer or len(answer.strip()) < MIN_ANSWER_LENGTH:
        return False

    lowered = answer.lower()
    if not any(keyword in lowered for keyword in MATH_KEYWORDS):
        return False
    if _BANNED_WORDS_RE.search(answer):
        return False
    if _REFUSAL_RE.match(answer):
        return False
    return True


# === Retrieval ===
def retrieve_from_qdrant(query):
    """Return the payloads of the three nearest stored problems for *query*.

    Hits without a payload are skipped so callers can always treat each item
    as a mapping. Returns an empty list when nothing is found.
    """
    # NOTE(review): e5 models are usually queried with a "query: " prefix;
    # left unchanged because the indexed vectors' prefixing is not visible here.
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=3,
    )
    return [hit.payload for hit in (hits or []) if hit.payload]


# === Web Search ===
def web_search_tavily(query):
    """Ask Tavily for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable.
    Returns the answer text, or ``TAVILY_NO_ANSWER`` when the key is missing,
    the request fails, or the response carries no answer.
    """
    # SECURITY: the key used to be hard-coded in this file; that key should be
    # considered leaked and rotated.
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        print("TAVILY_API_KEY is not set; skipping web search.")
        return TAVILY_NO_ANSWER

    try:
        response = requests.post(
            TAVILY_SEARCH_URL,
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Without this flag the response has no usable "answer" field.
                "include_answer": True,
            },
            timeout=TAVILY_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # Network failure, HTTP error status, or a non-JSON body.
        print(f"Tavily search failed: {exc}")
        return TAVILY_NO_ANSWER

    if not isinstance(payload, dict):
        return TAVILY_NO_ANSWER
    # "answer" may be present but null, so fall back on any falsy value.
    return payload.get("answer") or TAVILY_NO_ANSWER


# === Answer Generation ===
def generate_step_by_step_answer(question, context=""):
    """Generate a step-by-step solution for *question* with WizardMath.

    *context*, when non-empty, is inserted into the prompt between the
    question and the answer marker. Returns the generated text with the
    prompt removed and surrounding whitespace stripped.
    """
    prompt = f"### Question:\n{question}\n"
    if context:
        prompt += f"### Context:\n{context}\n"
    prompt += f"{ANSWER_MARKER}\n"

    # Follow the device chosen by device_map="auto" rather than assuming CUDA.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens; the marker split is kept as a
    # safeguard in case the model repeats the prompt scaffold.
    prompt_length = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return generated.split(ANSWER_MARKER)[-1].strip()


# === Router ===
def router(question):
    """Route *question* through the guardrails, retrieval, and web fallback.

    Non-math questions are refused. If retrieval yields at least one
    non-empty stored solution, the model answers with that context; when that
    answer fails the output guardrails (or there is no context) the Tavily
    answer is tried instead.
    """
    if not is_valid_math_question(question):
        return "❌ Only math questions are accepted. Please rephrase."

    context_items = retrieve_from_qdrant(question)
    # Drop empty solutions: joining blanks used to yield "\n\n", which is
    # truthy and triggered generation with an empty context.
    solutions = [
        str(item.get("solution"))
        for item in context_items
        if item and item.get("solution")
    ]
    context = "\n".join(solutions)

    if context:
        answer = generate_step_by_step_answer(question, context)
        if output_guardrails(answer):
            return answer

    answer = web_search_tavily(question)
    return answer if output_guardrails(answer) else "⚠️ No valid math answer found."


# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer):
    """Append one feedback record as a JSON line to ``FEEDBACK_PATH``."""
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now()),
    }
    with open(FEEDBACK_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")


# === Gradio UI ===
def ask_question(question):
    """Answer *question* and return (answer, question, answer).

    The repeated values feed the hidden fields that the feedback handler
    reads back.
    """
    answer = router(question)
    return answer, question, answer


def submit_feedback(question, model_answer, feedback):
    """Persist the user's rating for the last answer and return a status line."""
    # Avoid writing junk rows when nothing was asked or no rating was chosen.
    if not question or not feedback:
        return "⚠️ Please ask a question and choose a rating before submitting."
    store_feedback(question, model_answer, feedback, "")
    return "✅ Feedback received. Thank you!"

# === Gradio UI layout ===
# NOTE(review): the original line breaks were lost, so which widgets sit
# inside the Row is reconstructed (textbox + button) — confirm against the
# intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Tutor with AI Guardrails + Feedback")

    # Question entry and its trigger button.
    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
        submit_btn = gr.Button("Get Answer")

    answer_output = gr.Markdown()

    # Invisible fields that carry the last question/answer pair over to the
    # feedback handler.
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)

    submit_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, hidden_q, hidden_a],
    )

    gr.Markdown("### 📝 Feedback")
    fb_like = gr.Radio(["👍", "👎"], label="Was this answer helpful?")
    fb_submit_btn = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status", interactive=False)

    fb_submit_btn.click(
        fn=submit_feedback,
        inputs=[hidden_q, hidden_a, fb_like],
        outputs=[fb_status],
    )

# Start the app with a public share link and debug output enabled.
demo.launch(share=True, debug=True)