File size: 4,665 Bytes
b7b20e2
5775448
b7b20e2
 
 
 
 
 
5775448
b7b20e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
import os
from datetime import datetime

import dspy
import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# === Load Models ===
# All three models are created once, at import time, and shared by the
# functions below.
# Zero-shot classifier used by is_valid_math_question to gate user input.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# Sentence embedder used by retrieve_from_qdrant to turn a query into a vector.
embedding_model = SentenceTransformer("intfloat/e5-large")
# Generator that writes the final answer from retrieved context; weights are
# loaded in half precision and placed automatically (device_map="auto").
qa_pipeline = pipeline("text-generation", model="WizardLM/WizardMath-7B-V1.0", device_map="auto", torch_dtype=torch.float16)

# === Qdrant Setup ===
# Local, on-disk Qdrant store rooted at ./qdrant_data.
# NOTE(review): nothing in this file creates or populates the collection;
# presumably a separate ingestion step does — confirm before deploying.
qdrant_client = QdrantClient(path="qdrant_data")
collection_name = "math_problems"

# === Guard Function ===
def is_valid_math_question(text, threshold=0.7):
    """Return True if the zero-shot classifier judges *text* to be about math.

    Args:
        text: User-supplied question.
        threshold: Minimum score the top label must exceed (default 0.7).

    Returns:
        True only when the top-ranked label is "math" with a score above
        *threshold*; False otherwise, including for empty or blank input.
    """
    # The zero-shot pipeline raises ValueError on an empty sequence, and the
    # Gradio textbox sends "" when the button is clicked with nothing typed.
    if not text or not text.strip():
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    # The pipeline returns labels sorted by descending score, so index 0 wins.
    return result["labels"][0] == "math" and result["scores"][0] > threshold

# === Retrieval ===
def retrieve_from_qdrant(query, limit=3):
    """Return payloads of the stored math problems most similar to *query*.

    The query is embedded with ``embedding_model`` and searched against
    ``collection_name`` in the local Qdrant store.

    Args:
        query: Free-text question to search for.
        limit: Maximum number of hits to request (default 3).

    Returns:
        A list of payload dicts in the order Qdrant returned the hits.
        Hits that carry no payload are skipped, so every element is a
        non-empty dict. Returns an empty list when nothing matches.
    """
    # NOTE(review): intfloat/e5-* models are trained with "query: " /
    # "passage: " prefixes; whether one belongs here depends on how the
    # collection was indexed — confirm against the ingestion script.
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    # Points may be stored without a payload (payload is None); drop them so
    # callers can safely test `"key" in item`.
    return [hit.payload for hit in hits or [] if hit.payload]

# === Web Search ===
def web_search_tavily(query):
    """Ask the Tavily search API for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable,
    falling back to the original placeholder value.

    Returns:
        Tavily's answer string; "No answer found from Tavily." when the
        response carries no (or a null) answer; or a "Web search failed: ..."
        message when the HTTP request or JSON decoding fails.
    """
    api_key = os.environ.get("TAVILY_API_KEY", "your_tavily_api_key")
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Without this flag Tavily returns only result links and no
                # synthesized "answer" field.
                "include_answer": True,
            },
            # Never let a stalled request hang the Gradio worker.
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as exc:
        return f"Web search failed: {exc}"
    if not isinstance(data, dict):
        return "No answer found from Tavily."
    # "answer" can be present but null, so use `or` rather than a .get default.
    return data.get("answer") or "No answer found from Tavily."

# === DSPy Signature ===
class MathAnswer(dspy.Signature):
    # Declares the question/context -> answer interface for a DSPy predictor.
    # NOTE(review): not referenced anywhere in the visible file — the programs
    # below build their prompt by hand instead of using this signature.
    # A docstring is deliberately not added: DSPy treats a Signature's
    # docstring as its task instructions, so adding one would change prompts.
    question = dspy.InputField()  # the user's math question
    retrieved_context = dspy.InputField()  # solution text retrieved from Qdrant
    answer = dspy.OutputField()  # the generated answer

# === DSPy Programs ===
class MathRetrievalQA(dspy.Module):
    """Answer a question with the local LLM, grounded on retrieved solutions."""

    def forward(self, question):
        """Retrieve similar solved problems and generate an answer from them.

        Args:
            question: The user's math question.

        Returns:
            A ``dspy.Prediction`` with ``answer`` (only the newly generated
            text, prompt excluded) and ``retrieved_context`` (the joined
            "solution" payload fields). Both are "" when no retrieved payload
            has a "solution".
        """
        context_items = retrieve_from_qdrant(question)
        # Tolerate non-dict payloads and non-string solutions.
        context = "\n".join(
            str(item["solution"])
            for item in context_items
            if isinstance(item, dict) and "solution" in item
        )
        if not context:
            # An empty answer makes the router fall back to web search.
            return dspy.Prediction(answer="", retrieved_context="")
        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        # By default the text-generation pipeline returns the prompt followed
        # by the completion; return_full_text=False keeps only the completion.
        generated = qa_pipeline(prompt, max_new_tokens=512, return_full_text=False)
        answer = generated[0]["generated_text"].strip()
        return dspy.Prediction(answer=answer, retrieved_context=context)

class WebFallbackQA(dspy.Module):
    """Answer a question through Tavily web search."""

    def forward(self, question):
        """Return a ``dspy.Prediction`` holding Tavily's answer for *question*.

        ``retrieved_context`` is set to the literal source name "Tavily"
        rather than to actual context text.
        """
        answer = web_search_tavily(question)
        return dspy.Prediction(answer=answer, retrieved_context="Tavily")

class MathRouter(dspy.Module):
    """Route a question: reject non-math, try local retrieval, else web search."""

    def __init__(self):
        super().__init__()
        # Build the sub-programs once instead of on every request.
        self.retrieval_qa = MathRetrievalQA()
        self.web_qa = WebFallbackQA()

    def forward(self, question):
        """Answer *question*, returning a ``dspy.Prediction``.

        Blank or non-math input gets a rejection message. Otherwise the
        retrieval answer is used when it is non-empty, and web search is used
        when it is empty.
        """
        if not question or not question.strip() or not is_valid_math_question(question):
            return dspy.Prediction(
                answer="❌ Only math questions are accepted. Please rephrase.",
                retrieved_context="",
            )
        result = self.retrieval_qa.forward(question)
        return result if result.answer else self.web_qa.forward(question)

# Single module-level router instance used by ask_question for every request.
router = MathRouter()

# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer, path="feedback.json"):
    """Append one feedback record as a single JSON line to *path*.

    Args:
        question: The question the user asked.
        answer: The model's answer being rated (stored as "model_answer").
        feedback: The user's rating from the feedback radio button.
        correct_answer: Optional user-supplied correct answer.
        path: File to append to (default "feedback.json"). Each call writes
            one JSON object followed by a newline (JSON Lines).
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        # Local, naive timestamp in str(datetime) format; unchanged so new
        # records stay consistent with existing ones.
        "timestamp": str(datetime.now()),
    }
    # Explicit encoding so the file does not depend on the platform locale.
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")

# === Gradio Functions ===
def ask_question(question):
    """Run *question* through the router and shape the result for Gradio.

    Returns a 3-tuple: the answer to display, the question echoed back, and
    the answer again (these two feed the hidden state textboxes).
    """
    prediction = router.forward(question)
    answer_text = prediction.answer
    return answer_text, question, answer_text

def submit_feedback(question, model_answer, feedback, correct_answer):
    """Persist one feedback entry and return a status message for the UI.

    Args:
        question: The original question from the feedback form.
        model_answer: The model's answer being rated.
        feedback: The selected rating.
        correct_answer: Optional corrected answer typed by the user.

    Returns:
        A confirmation string shown in the status textbox.
    """
    store_feedback(question, model_answer, feedback, correct_answer)
    # The emoji was previously mis-encoded as "โœ…".
    return "✅ Feedback received. Thank you!"

# === Gradio UI ===
# Two-tab UI: ask a question, then rate the answer. Emoji in the labels were
# previously mis-encoded (e.g. "๐Ÿ‘" for 👍).
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
        answer_output = gr.Markdown(label="Answer")
        # Hidden state holding the last question/answer pair.
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn = gr.Button("Get Answer")

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        fb_like = gr.Radio(["👍", "👎"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)

    # Events are wired after both tabs exist so the answer can pre-fill the
    # feedback form; before, hidden_q/hidden_a were written but never read.
    submit_btn.click(
        fn=ask_question,
        inputs=[question_input],
        outputs=[answer_output, hidden_q, hidden_a],
    ).then(
        fn=lambda q, a: (q, a),
        inputs=[hidden_q, hidden_a],
        outputs=[fb_question, fb_answer],
    )
    fb_submit_btn.click(
        fn=submit_feedback,
        inputs=[fb_question, fb_answer, fb_like, fb_correct],
        outputs=[fb_status],
    )

# Start the server only when run as a script, so importing this module
# (e.g. from tests) does not launch Gradio.
if __name__ == "__main__":
    demo.launch()