File size: 4,233 Bytes
5775448
9ec24d8
b7b20e2
 
 
af77c21
 
 
5eea801
af77c21
b7b20e2
9ec24d8
b7b20e2
9ec24d8
 
b7b20e2
9ec24d8
5eea801
 
 
af77c21
b7b20e2
9ec24d8
b7b20e2
 
 
5eea801
b7b20e2
 
 
 
 
5eea801
b7b20e2
 
5eea801
 
 
 
 
b7b20e2
 
5eea801
b7b20e2
02f7269
b7b20e2
 
 
 
 
 
5eea801
 
 
 
af77c21
 
 
 
 
 
 
 
5eea801
af77c21
 
 
 
5eea801
af77c21
5eea801
 
af77c21
 
5eea801
 
 
 
b7b20e2
 
5eea801
b7b20e2
 
 
 
 
 
 
 
 
5eea801
b7b20e2
5eea801
af77c21
63e04ea
5eea801
 
b7b20e2
 
5eea801
b7b20e2
5eea801
a6a2ff2
 
 
af77c21
a6a2ff2
5eea801
a6a2ff2
 
af77c21
a6a2ff2
 
5eea801
 
 
af77c21
a6a2ff2
5eea801
 
 
 
 
b7b20e2
5eea801
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126

import gradio as gr
import torch
import requests
import json
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from datetime import datetime

# === Load Models ===
# Hugging Face model identifiers, named once so they are easy to locate and swap.
_CLASSIFIER_MODEL_ID = "facebook/bart-large-mnli"
_EMBEDDING_MODEL_ID = "intfloat/e5-large"
_GENERATOR_MODEL_ID = "microsoft/phi-2"

# Zero-shot classifier: used by the guard to decide whether input is math.
print("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model=_CLASSIFIER_MODEL_ID)

# Sentence encoder: turns a question into the vector sent to Qdrant.
print("Loading embedding model...")
embedding_model = SentenceTransformer(_EMBEDDING_MODEL_ID)

# Causal language model and its tokenizer: write the step-by-step answer.
print("Loading step-by-step generator...")
tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(_GENERATOR_MODEL_ID)

# === Qdrant Setup ===
# The client is opened on the local "qdrant_data" path; retrieval queries the
# collection named below.
print("Connecting to Qdrant...")
collection_name = "math_problems"
qdrant_client = QdrantClient(path="qdrant_data")

# === Guard Function ===
def is_valid_math_question(text):
    """Return True when the zero-shot classifier confidently labels *text* as math.

    Args:
        text: The raw question entered by the user.

    Returns:
        True if the first (top-ranked) label reported by the classifier is
        "math" and its score is strictly greater than 0.7; False otherwise,
        including for empty, whitespace-only, or missing input.
    """
    # Nothing to classify: reject blank input up front instead of pushing an
    # empty sequence through the model.
    if not text or not text.strip():
        return False

    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    # Index 0 of the parallel `labels` / `scores` lists is treated as the
    # classifier's best guess.
    top_label = result["labels"][0]
    top_score = result["scores"][0]
    return top_label == "math" and top_score > 0.7

# === Retrieval from Qdrant ===
def retrieve_from_qdrant(query):
    """Return the payloads of the stored entries most similar to *query*.

    The query text is embedded with the module-level ``embedding_model`` and
    the resulting vector is searched against ``collection_name`` through the
    module-level ``qdrant_client``.

    Args:
        query: The question text to search for.

    Returns:
        A list with at most three payloads, one per matching point. Points
        without a payload are skipped; the list is empty when nothing matches.
    """
    query_vector = embedding_model.encode(query).tolist()
    # `query_points` receives the search vector through its `query` argument
    # and returns a response object; the scored matches are in `.points`.
    response = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=3,
    )
    return [point.payload for point in response.points if point.payload]

# === Web Search Fallback ===
def web_search_tavily(query):
    """Ask the Tavily search API for a direct answer to *query*.

    Args:
        query: The question to send to Tavily.

    Returns:
        The ``answer`` string from Tavily's JSON response, or the text
        "No answer found from Tavily." when the request fails, the response
        is not usable JSON, or no answer is present.
    """
    import os  # local import so this function stays self-contained

    no_answer = "No answer found from Tavily."

    # SECURITY NOTE(review): the fallback key below is committed to source and
    # should be treated as leaked. Rotate it and provide the replacement via
    # the TAVILY_API_KEY environment variable; the literal is kept only so
    # existing deployments keep working until that is done.
    api_key = os.environ.get(
        "TAVILY_API_KEY", "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"
    )

    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Without this flag the response does not carry a synthesized
                # answer, which is the only field read below.
                "include_answer": True,
            },
            # Never let a stalled connection hang the UI callback forever.
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as exc:
        # Network failure, HTTP error status, or a body that is not JSON.
        print(f"Tavily search failed: {exc}")
        return no_answer

    if not isinstance(payload, dict):
        return no_answer
    # `answer` may be present but null/empty; fall back in that case too.
    return payload.get("answer") or no_answer

# === Generator ===
def generate_step_by_step_answer(question, context):
    """Generate a step-by-step answer to *question* using *context* as reference.

    The prompt is built from the question and the context, tokenized with the
    module-level ``tokenizer`` and completed by the module-level ``model``
    using sampling (temperature 0.7, top-p 0.95, up to 256 new tokens).

    Args:
        question: The math question to answer.
        context: Reference text (for example retrieved solutions) placed in
            the prompt.

    Returns:
        The generated continuation only, with the prompt removed and
        surrounding whitespace stripped.
    """
    prompt = f"Answer the following math question step-by-step:\nQuestion: {question}\nContext: {context}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        # Reuse EOS as the padding token for generation.
        pad_token_id=tokenizer.eos_token_id,
    )
    # For a causal LM the returned sequence starts with the prompt tokens.
    # Decode only what comes after them so the caller gets the answer rather
    # than the prompt and context echoed back.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

# === Router ===
def router(question):
    """Route *question* to retrieval-backed generation or to web search.

    Flow:
      1. Reject anything the guard does not classify as a math question.
      2. Retrieve stored entries and join their "solution" fields into a
         context string.
      3. With a non-empty context, generate a step-by-step answer from it;
         otherwise fall back to the Tavily web search.

    Args:
        question: The text entered by the user.

    Returns:
        A ``(answer, source)`` tuple. ``source`` is the joined context when
        the generator was used, "Tavily Search" for the web fallback, and an
        empty string when the question was rejected.
    """
    if not is_valid_math_question(question):
        return "❌ Only math questions are accepted. Please rephrase.", ""

    retrieved = retrieve_from_qdrant(question)
    # Keep only payloads that are mappings with a non-empty solution, and
    # coerce to str so a non-string solution cannot break the join.
    solutions = [
        str(item["solution"])
        for item in retrieved
        if isinstance(item, dict) and item.get("solution")
    ]
    context = "\n".join(solutions)

    if context:
        return generate_step_by_step_answer(question, context), context
    return web_search_tavily(question), "Tavily Search"

# === Feedback Storage ===
def store_feedback(question, answer, correct_answer):
    """Append one feedback record to ``feedback.json`` as a single JSON line.

    Args:
        question: The question that was asked.
        answer: The answer the model produced (stored as ``model_answer``).
        correct_answer: The correction supplied by the user.

    Each call adds one line holding the three values plus a ``timestamp``
    taken from the local clock at the time of the call.
    """
    record = dict(
        question=question,
        model_answer=answer,
        correct_answer=correct_answer,
        # Same text as str(datetime.now()): ISO format with a space separator.
        timestamp=datetime.now().isoformat(sep=" "),
    )
    # Append mode: earlier feedback is kept, one JSON object per line.
    with open("feedback.json", "a") as log_file:
        print(json.dumps(record), file=log_file)

# === Gradio Functions ===
def ask_question(question):
    """Gradio callback: answer *question* and echo the values the UI stores.

    Args:
        question: The text from the question box.

    Returns:
        A 3-tuple ``(answer, question, answer)``: the answer to display,
        followed by the question and the answer again for the two hidden
        fields that the feedback form reads later.
    """
    # The second element returned by the router is not used by the UI.
    answer, _unused_context = router(question)
    return (answer, question, answer)

def submit_feedback(question, model_answer, correct_answer):
    """Gradio callback: persist user feedback and return a status message.

    Args:
        question: The question the feedback refers to.
        model_answer: The answer that was shown to the user.
        correct_answer: The correction typed by the user (may be empty).

    Returns:
        A confirmation message once the feedback has been stored, or a hint
        to ask a question first when *question* is empty.
    """
    # With no question there is nothing for the feedback to refer to; skip
    # writing an empty record.
    if not question or not question.strip():
        return "⚠️ Please ask a question before submitting feedback."

    store_feedback(question, model_answer, correct_answer)
    return "✅ Feedback received. Thank you!"

# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with Retrieval + Feedback")

    # --- Ask a question ---
    with gr.Row():
        question_input = gr.Textbox(label="Enter your math question", lines=2)
        submit_btn = gr.Button("Get Answer")

    answer_output = gr.Markdown(label="Answer")
    # Invisible fields that carry the last question and answer from the
    # "Get Answer" click over to the feedback form below.
    hidden_q = gr.Textbox(visible=False)
    hidden_a = gr.Textbox(visible=False)

    submit_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, hidden_q, hidden_a],
    )

    # --- Submit feedback on the last answer ---
    gr.Markdown("### 📝 Submit Feedback")
    fb_correct = gr.Textbox(label="Correct Answer (optional)")
    fb_submit = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status", interactive=False)

    fb_submit.click(
        fn=submit_feedback,
        inputs=[hidden_q, hidden_a, fb_correct],
        outputs=[fb_status],
    )

# Start the server only when this file is run as a script, so importing the
# module (for tooling or reuse of `demo`) does not launch it.
if __name__ == "__main__":
    demo.launch(share=True)