import gradio as gr
import torch
import requests
import re
from datetime import datetime
import json
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
# === Load Models ===
print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
print("Loading embedding model...")
embedding_model = SentenceTransformer("intfloat/e5-large")
print("Loading WizardMath model...")
tokenizer = AutoTokenizer.from_pretrained("WizardLM/WizardMath-7B-V1.1")
model = AutoModelForCausalLM.from_pretrained(
    "WizardLM/WizardMath-7B-V1.1", torch_dtype=torch.float16, device_map="auto"
)
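# NOTE: torch_dtype=float16 with device_map="auto" assumes a CUDA GPU and the
# `accelerate` package; on CPU-only hardware, drop both arguments (slower, and
# the 7B model needs considerably more RAM in full precision).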
# === Qdrant Setup ===
print("Connecting to Qdrant...")
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"
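# NOTE: QdrantClient(path=...) runs Qdrant in embedded/local mode (no server).
# The "math_problems" collection is assumed to already exist on disk with a
# vector size matching the embedding model (1024 dims for intfloat/e5-large).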
# === Guard Functions ===
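# NOTE: the guards below are lightweight heuristics: a zero-shot topic gate on
# the input and keyword/pattern checks on the output. They can produce false
# rejects (e.g., a correct answer that happens to avoid the math keywords).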
def is_valid_math_question(text):
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    return result['labels'][0] == "math" and result['scores'][0] > 0.7
def output_guardrails(answer):
    if not answer or len(answer.strip()) < 10:
        return False
    math_keywords = ["solve", "equation", "integral", "derivative", "value", "expression", "steps", "solution"]
    if not any(word in answer.lower() for word in math_keywords):
        return False
    banned_keywords = ["kill", "bomb", "hate", "politics", "violence"]
    if any(word in answer.lower() for word in banned_keywords):
        return False
    # Group the alternatives so ^\s* applies to every refusal phrase, not only the first.
    if re.match(r"^\s*(?:I'm just a model|Sorry, I can't|As an AI)", answer, re.IGNORECASE):
        return False
    return True
# === Retrieval ===
def retrieve_from_qdrant(query):
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(collection_name=collection_name, query_vector=query_vector, limit=3)
    return [hit.payload for hit in hits] if hits else []
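# NOTE: intfloat/e5 models are trained with "query: " / "passage: " prefixes;
# if the collection was indexed from "passage: "-prefixed texts, encoding
# "query: " + query above should improve retrieval quality.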
# === Web Search ===
def web_search_tavily(query):
    # Read the API key from the environment rather than hard-coding a secret.
    TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY", "")
    response = requests.post(
        "https://api.tavily.com/search",
        # include_answer asks Tavily to return a synthesized "answer" field.
        json={"api_key": TAVILY_API_KEY, "query": query, "search_depth": "advanced", "include_answer": True},
    )
    return response.json().get("answer", "No answer found from Tavily.")
# === Answer Generation ===
def generate_step_by_step_answer(question, context=""):
    prompt = f"### Question:\n{question}\n"
    if context:
        prompt += f"### Context:\n{context}\n"
    prompt += "### Let's solve it step by step:\n"
    # Move inputs to the device the model was placed on by device_map="auto".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text generated after the step-by-step marker.
    answer = decoded.split("### Let's solve it step by step:")[-1].strip()
    return answer
# === Router ===
def router(question):
    # Input guard: refuse anything the classifier does not label as math.
    if not is_valid_math_question(question):
        return "❌ Only math questions are accepted. Please rephrase."
    # Try retrieval first: answer with stored solutions as context.
    context_items = retrieve_from_qdrant(question)
    context = "\n".join([item.get("solution", "") for item in context_items])
    if context:
        answer = generate_step_by_step_answer(question, context)
        if output_guardrails(answer):
            return answer
    # Fall back to web search when retrieval yields nothing usable.
    answer = web_search_tavily(question)
    return answer if output_guardrails(answer) else "⚠️ No valid math answer found."
# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer):
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now()),
    }
    # Append one JSON object per line (JSONL) so feedback accumulates across runs.
    with open("feedback.json", "a") as f:
        f.write(json.dumps(entry) + "\n")
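# Each stored line is shaped like (values illustrative):
# {"question": "...", "model_answer": "...", "feedback": "👍", "correct_answer": "", "timestamp": "2024-01-01 12:00:00.000000"}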
# === Gradio UI ===
def ask_question(question):
    answer = router(question)
    # Return the answer for display plus hidden copies for the feedback form.
    return answer, question, answer

def submit_feedback(question, model_answer, feedback):
    store_feedback(question, model_answer, feedback, "")
    return "✅ Feedback received. Thank you!"
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Tutor with AI Guardrails + Feedback")
    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
        submit_btn = gr.Button("Get Answer")
    answer_output = gr.Markdown()
    # Hidden fields carry the last question/answer into the feedback handler.
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)
    submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])
    gr.Markdown("### 📝 Feedback")
    fb_like = gr.Radio(["👍", "👎"], label="Was this answer helpful?")
    fb_submit_btn = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status", interactive=False)
    fb_submit_btn.click(fn=submit_feedback,
                        inputs=[hidden_q, hidden_a, fb_like],
                        outputs=[fb_status])
demo.launch(share=True, debug=True)