import gradio as gr
import torch
import requests
import re
import os
from datetime import datetime
import json

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient

# === Load Models ===
print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

print("Loading embedding model...")
embedding_model = SentenceTransformer("intfloat/e5-large")

print("Loading WizardMath model...")
tokenizer = AutoTokenizer.from_pretrained("WizardLM/WizardMath-7B-V1.1")
model = AutoModelForCausalLM.from_pretrained(
    "WizardLM/WizardMath-7B-V1.1", torch_dtype=torch.float16, device_map="auto"
)

# === Qdrant Setup ===
print("Connecting to Qdrant...")
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"
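
# The retrieval step below assumes a "math_problems" collection already exists
# in the local Qdrant store, with each point carrying a {"solution": ...} payload.
# The helper below is a minimal, hypothetical bootstrap sketch (not part of the
# original app): it creates the collection if missing and upserts two toy
# problems. It assumes e5-large's 1024-dimensional embeddings and follows the
# e5 convention of prefixing stored documents with "passage: ".
from qdrant_client.models import Distance, VectorParams, PointStruct

def seed_collection_if_missing():
    existing = [c.name for c in qdrant_client.get_collections().collections]
    if collection_name in existing:
        return
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
    )
    examples = [
        "Solve 2x + 3 = 7. Subtract 3 from both sides: 2x = 4, so x = 2.",
        "Differentiate x^2: by the power rule, d/dx x^2 = 2x.",
    ]
    points = [
        PointStruct(
            id=i,
            vector=embedding_model.encode("passage: " + text).tolist(),
            payload={"solution": text},
        )
        for i, text in enumerate(examples)
    ]
    qdrant_client.upsert(collection_name=collection_name, points=points)

# seed_collection_if_missing()  # uncomment once to populate a fresh store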

# === Guard Functions ===
def is_valid_math_question(text):
    # Zero-shot gate: accept only if the classifier labels the text "math"
    # with confidence above 0.7.
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    return result['labels'][0] == "math" and result['scores'][0] > 0.7

def output_guardrails(answer):
    # Reject empty or trivially short answers.
    if not answer or len(answer.strip()) < 10:
        return False
    # Require at least one math-related keyword.
    math_keywords = ["solve", "equation", "integral", "derivative", "value", "expression", "steps", "solution"]
    if not any(word in answer.lower() for word in math_keywords):
        return False
    # Reject answers touching banned topics.
    banned_keywords = ["kill", "bomb", "hate", "politics", "violence"]
    if any(word in answer.lower() for word in banned_keywords):
        return False
    # Reject refusal boilerplate; the alternation is grouped so every phrase
    # (not just the first) is anchored at the start of the answer.
    if re.match(r"^\s*(I'm just a model|Sorry, I can't|As an AI)", answer, re.IGNORECASE):
        return False
    return True
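
# Illustrative behavior of the guards (classifier-dependent, not guaranteed):
#   is_valid_math_question("Solve 2x + 3 = 7")        -> typically True
#   is_valid_math_question("Who won the World Cup?")  -> typically False
#   output_guardrails("x = 2 is the solution.")       -> True
#   output_guardrails("As an AI, I cannot help.")     -> False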

# === Retrieval ===
def retrieve_from_qdrant(query):
    # e5 models expect a "query: " prefix on search queries (and "passage: "
    # on stored documents) for best retrieval quality.
    query_vector = embedding_model.encode("query: " + query).tolist()
    hits = qdrant_client.search(collection_name=collection_name, query_vector=query_vector, limit=3)
    return [hit.payload for hit in hits] if hits else []

# === Web Search ===
def web_search_tavily(query):
    # Read the key from the environment instead of hardcoding a secret in source.
    tavily_api_key = os.environ.get("TAVILY_API_KEY", "")
    response = requests.post(
        "https://api.tavily.com/search",
        json={
            "api_key": tavily_api_key,
            "query": query,
            "search_depth": "advanced",
            "include_answer": True,  # ask Tavily to return a synthesized "answer" field
        },
        timeout=30,
    )
    return response.json().get("answer", "No answer found from Tavily.")

# === Answer Generation ===
def generate_step_by_step_answer(question, context=""):
    prompt = f"### Question:\n{question}\n"
    if context:
        prompt += f"### Context:\n{context}\n"
    prompt += "### Let's solve it step by step:\n"

    # Move inputs to whichever device device_map="auto" placed the model on,
    # rather than hardcoding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    answer = decoded.split("### Let's solve it step by step:")[-1].strip()
    return answer

# === Router ===
def router(question):
    if not is_valid_math_question(question):
        return "โŒ Only math questions are accepted. Please rephrase."

    context_items = retrieve_from_qdrant(question)
    context = "\n".join([item.get("solution", "") for item in context_items])

    if context:
        answer = generate_step_by_step_answer(question, context)
        if output_guardrails(answer):
            return answer

    answer = web_search_tavily(question)
    return answer if output_guardrails(answer) else "⚠️ No valid math answer found."

# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer):
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        "timestamp": str(datetime.now())
    }
    with open("feedback.json", "a") as f:
        f.write(json.dumps(entry) + "\n")
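
# Hypothetical companion helper (not in the original script): entries are
# appended as JSON Lines, so they can be read back like this for later review
# or for building a fine-tuning set.
def load_feedback(path="feedback.json"):
    entries = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                entries.append(json.loads(line))
    return entries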

# === Gradio UI ===
def ask_question(question):
    answer = router(question)
    return answer, question, answer

def submit_feedback(question, model_answer, feedback):
    store_feedback(question, model_answer, feedback, "")
    return "โœ… Feedback received. Thank you!"

with gr.Blocks() as demo:
    gr.Markdown("## ๐Ÿงฎ Math Tutor with AI Guardrails + Feedback")

    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
        submit_btn = gr.Button("Get Answer")

    answer_output = gr.Markdown()
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)

    submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])

    gr.Markdown("### ๐Ÿ“ Feedback")
    fb_like = gr.Radio(["๐Ÿ‘", "๐Ÿ‘Ž"], label="Was this answer helpful?")
    fb_submit_btn = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status", interactive=False)

    fb_submit_btn.click(fn=submit_feedback,
                        inputs=[hidden_q, hidden_a, fb_like],
                        outputs=[fb_status])

demo.launch(share=True, debug=True)