File size: 5,740 Bytes
5775448
9ec24d8
b7b20e2
 
d16f9ab
af77c21
 
5eea801
d16f9ab
 
af77c21
b7b20e2
9ec24d8
b7b20e2
9ec24d8
 
b7b20e2
9ec24d8
d16f9ab
 
 
af77c21
b7b20e2
9ec24d8
b7b20e2
 
 
5eea801
b7b20e2
 
 
d16f9ab
b7b20e2
 
d16f9ab
b7b20e2
d16f9ab
b7b20e2
d16f9ab
 
b7b20e2
 
d16f9ab
b7b20e2
d16f9ab
02f7269
b7b20e2
 
 
 
 
 
d16f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
6d00b6b
d16f9ab
6d00b6b
 
 
 
d16f9ab
6d00b6b
d16f9ab
 
 
 
 
6d00b6b
 
d16f9ab
 
 
 
 
 
6d00b6b
d16f9ab
6d00b6b
 
 
b7b20e2
 
d16f9ab
b7b20e2
 
 
d16f9ab
b7b20e2
 
 
d16f9ab
b7b20e2
 
 
5eea801
b7b20e2
d16f9ab
 
 
6d00b6b
 
 
d16f9ab
 
 
 
b7b20e2
 
5eea801
b7b20e2
d16f9ab
 
 
 
 
 
 
 
 
 
 
af77c21
d16f9ab
a6a2ff2
d16f9ab
87cc698
 
 
 
 
 
 
6d00b6b
 
 
d16f9ab
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

import json
import os
from datetime import datetime

import dspy
import gradio as gr
import requests
import torch
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# === Load Models ===
# NOTE: everything in this section runs at import time. Each model/client is
# constructed once when the module is executed (models may be downloaded on
# first run), and the resulting globals are used by the functions below.
print("Loading zero-shot classifier...")
# Zero-shot NLI classifier; used by is_valid_math_question() to gate input.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

print("Loading embedding model...")
# Sentence-embedding model; used by retrieve_from_qdrant() to embed queries.
embedding_model = SentenceTransformer("intfloat/e5-large")

print("Loading text generation model...")
# NOTE(review): no generation model is actually loaded; the pipeline line
# below is commented out, so the message above is misleading. Generation in
# MathRetrievalQA is stubbed accordingly.
# Use a lighter model for testing
#qa_pipeline = pipeline("text-generation", model="gpt2")

# === Qdrant Setup ===
print("Connecting to Qdrant...")
# Local Qdrant instance persisted under ./qdrant_data (path= selects local mode).
qdrant_client = QdrantClient(path="qdrant_data")
# Collection searched by retrieve_from_qdrant(); presumably populated by a
# separate indexing script — verify.
collection_name = "math_problems"

# === Guard Function ===
def is_valid_math_question(text, threshold=0.7):
    """Return True if the zero-shot classifier labels *text* as math.

    Args:
        text: The user's question. Empty, whitespace-only or None input is
            rejected without invoking the classifier.
        threshold: The "math" label must score strictly above this value.
            Defaults to 0.7, the previously hard-coded cut-off.

    Returns:
        True when the top-ranked label is "math" and its score exceeds
        *threshold*; False otherwise.
    """
    # Guard: an empty question is never a valid math question, and running the
    # NLI pipeline on an empty string is wasted work that may also raise.
    if not text or not text.strip():
        return False
    candidate_labels = ["math", "not math"]
    result = classifier(text, candidate_labels)
    print("Classifier result:", result)
    # The zero-shot pipeline returns labels ordered by descending score, so
    # index 0 is the winning label.
    return result['labels'][0] == "math" and result['scores'][0] > threshold

# === Retrieval ===
def retrieve_from_qdrant(query, limit=3):
    """Embed *query* and return the payloads of the closest stored points.

    Args:
        query: Free-text question to search for.
        limit: Maximum number of hits to request (default 3, as before).

    Returns:
        A list of payload dicts in the order Qdrant returned them. Hits
        without a payload are skipped, so callers can safely test
        ``"solution" in item``. Returns [] when nothing matches.
    """
    print("Retrieving context from Qdrant...")
    # NOTE(review): the intfloat/e5 model card asks for a "query: " prefix on
    # queries (and "passage: " on indexed text). Not added here because it must
    # match how the collection was indexed — confirm.
    query_vector = embedding_model.encode(query).tolist()
    hits = qdrant_client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
    print("Retrieved hits:", hits)
    # A point's payload may be None; drop those so downstream membership tests
    # on the payload do not raise TypeError.
    return [hit.payload for hit in (hits or []) if hit.payload]

# === Web Search ===
def web_search_tavily(query):
    """Ask the Tavily search API for a direct answer to *query*.

    The API key is read from the ``TAVILY_API_KEY`` environment variable so
    that it is never committed to source control.

    Args:
        query: The question to search for.

    Returns:
        Tavily's generated answer. If no key is configured, the request fails,
        or Tavily returns no answer, a human-readable message is returned
        instead. The result is always a str, never None, so it is safe to
        render.
    """
    print("Calling Tavily...")
    # SECURITY: a key used to be hard-coded here. Treat that key as leaked and
    # rotate it; configure the replacement through the environment.
    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        return "Tavily API key is not configured (set TAVILY_API_KEY)."
    try:
        response = requests.post(
            "https://api.tavily.com/search",
            json={
                "api_key": api_key,
                "query": query,
                "search_depth": "advanced",
                # Tavily only generates the "answer" field when asked to;
                # without this flag the answer comes back null.
                "include_answer": True,
            },
            # Without a timeout a stalled connection would hang the UI request.
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as err:
        # Network/HTTP errors and non-JSON bodies degrade to the default reply.
        print("Tavily request failed:", err)
        return "No answer found from Tavily."
    # "answer" may be present but null, so use `or` rather than a .get default.
    return data.get("answer") or "No answer found from Tavily."

# === DSPy Signature ===
class MathAnswer(dspy.Signature):
    # DSPy signature for a math QA step: question + retrieved context in,
    # answer out.
    # NOTE(review): no program in the visible source uses this signature
    # (generation is stubbed out). A class docstring is deliberately NOT
    # added: DSPy uses a Signature's docstring as the task instructions sent
    # to the LM, so adding one would change behavior once this is wired up.
    question = dspy.InputField()  # the user's math question
    retrieved_context = dspy.InputField()  # context text, presumably the newline-joined Qdrant solutions — confirm
    answer = dspy.OutputField()  # the answer the LM should produce

# === DSPy Programs ===
class MathRetrievalQA(dspy.Program):
    """Answer a math question from solutions stored in Qdrant.

    ``forward`` returns a plain dict with ``"answer"`` and
    ``"retrieved_context"`` keys, which ``MathRouter`` indexes into. An empty
    ``"answer"`` means nothing usable was retrieved, so the caller should
    fall back to another source.
    """

    def forward(self, question):
        print("Inside MathRetrievalQA...")
        context_items = retrieve_from_qdrant(question)
        # Skip payloads that are None or carry no "solution"; str() guards
        # against non-string solution values breaking the join.
        context = "\n".join(
            str(item["solution"])
            for item in context_items
            if item and "solution" in item
        )
        print("Context for generation:", context)
        if not context:
            # Same dict shape as the success path. DSPy has no `dspy.Output`;
            # constructing one raised AttributeError and skipped the fallback.
            return {"answer": "", "retrieved_context": ""}
        prompt = f"Question: {question}\nContext: {context}\nAnswer:"
        print("Generating answer...")
        # NOTE(review): text generation is stubbed out (no qa_pipeline is
        # loaded), so the assembled prompt itself is returned as the answer.
        print("Generated answer:", prompt)
        return {"answer": prompt, "retrieved_context": context}

class WebFallbackQA(dspy.Program):
    """Fallback program that answers a question via Tavily web search.

    ``forward`` returns a dict with ``answer`` (Tavily's reply) and
    ``retrieved_context`` set to the source tag ``"Tavily"``.
    """

    def forward(self, question):
        print("Fallback to Tavily...")
        web_answer = web_search_tavily(question)
        return dict(answer=web_answer, retrieved_context="Tavily")


class MathRouter(dspy.Program):
    """Top-level router for incoming questions.

    Non-math questions are rejected. Math questions go to Qdrant retrieval
    first, then to web search if retrieval produced no answer. ``forward``
    returns a dict with ``"answer"`` and ``"retrieved_context"`` keys.
    """

    def forward(self, question):
        print("Routing question:", question)
        if not is_valid_math_question(question):
            # Plain dict, matching the other branches. DSPy has no
            # `dspy.Output`; constructing one raised AttributeError here.
            return {
                "answer": "❌ Only math questions are accepted. Please rephrase.",
                "retrieved_context": "",
            }
        result = MathRetrievalQA().forward(question)
        # An empty answer means retrieval found nothing usable.
        if result["answer"]:
            return result
        return WebFallbackQA().forward(question)


router = MathRouter()

# === Feedback Storage ===
def store_feedback(question, answer, feedback, correct_answer, path="feedback.json"):
    """Append one feedback record to *path* as a single JSON line.

    Args:
        question: The question the user asked.
        answer: The model answer being rated (stored as ``model_answer``).
        feedback: The user's rating value.
        correct_answer: Optional user-supplied correct answer.
        path: JSON-Lines file to append to. Defaults to ``feedback.json``
            relative to the current working directory, as before.
    """
    entry = {
        "question": question,
        "model_answer": answer,
        "feedback": feedback,
        "correct_answer": correct_answer,
        # Naive local time, str()-formatted (format kept for compatibility).
        "timestamp": str(datetime.now()),
    }
    print("Storing feedback:", entry)
    # One JSON object per line, so appending never rewrites earlier records.
    # The explicit encoding removes any dependence on the platform locale.
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")

# === Gradio Functions ===
def ask_question(question):
    """Gradio handler for the "Get Answer" button.

    Returns:
        A 3-tuple ``(answer, question, answer)``, one value per output
        component the button is wired to: the visible answer plus the hidden
        question/answer textboxes. Returning a single value for three outputs
        makes Gradio report an output-count mismatch.
    """
    print("ask_question() called with:", question)
    result = router.forward(question)
    print("Result:", result)
    # Coerce a missing/None answer to "" so every output component gets a str.
    answer = result.get("answer") or ""
    return answer, question, answer


def submit_feedback(question, model_answer, feedback, correct_answer):
    """Gradio handler for "Submit Feedback": persist the rating and confirm.

    Returns:
        The confirmation message shown in the feedback tab's Status box.
    """
    store_feedback(question, model_answer, feedback, correct_answer)
    return "✅ Feedback received. Thank you!"

# === Gradio UI ===
with gr.Blocks() as demo:
    gr.Markdown("## 🧮 Math Question Answering with DSPy + Feedback")

    with gr.Tab("Ask a Math Question"):
        with gr.Row():
            question_input = gr.Textbox(label="Enter your math question", lines=2)
        gr.Markdown("### 🧠 Answer:")
        answer_output = gr.Markdown()

        # Hidden holders for the last question/answer. The button's handler
        # output is wired to [answer_output, hidden_q, hidden_a], so the
        # handler must return three values.
        hidden_q = gr.Textbox(visible=False)
        hidden_a = gr.Textbox(visible=False)
        submit_btn = gr.Button("Get Answer")
        submit_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, hidden_q, hidden_a])

    with gr.Tab("Submit Feedback"):
        gr.Markdown("### Was the answer helpful?")
        fb_question = gr.Textbox(label="Original Question")
        fb_answer = gr.Textbox(label="Model's Answer")
        # Two distinct choices; the chosen value is stored verbatim as
        # "feedback" by submit_feedback/store_feedback.
        fb_like = gr.Radio(["👍", "👎"], label="Your Feedback")
        fb_correct = gr.Textbox(label="Correct Answer (optional)")
        fb_submit_btn = gr.Button("Submit Feedback")
        fb_status = gr.Textbox(label="Status", interactive=False)
        fb_submit_btn.click(fn=submit_feedback,
                            inputs=[fb_question, fb_answer, fb_like, fb_correct],
                            outputs=[fb_status])

if __name__ == "__main__":
    # share=True also requests a public share link; debug=True blocks the
    # main thread and surfaces errors in the console.
    demo.launch(share=True, debug=True)