File size: 3,227 Bytes
9bf33d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from langgraph.graph import StateGraph, END
from langchain.chains import RetrievalQA
from typing import TypedDict, Optional
from tools import llm, load_vectorstore, search_tool

# Load your vectorstore
# Module-level singleton: built once at import time and reused by every
# pdf_qa_node call. load_vectorstore() comes from the local `tools` module;
# its backend (FAISS/Chroma/etc.) is not visible here — see tools.py.
vectorstore = load_vectorstore()

# --- TypedDict to define graph state schema ---
class GraphState(TypedDict):
    """Shared state threaded through the LangGraph pipeline.

    Each worker node reads ``question`` and writes at most one of the
    answer fields; the ``respond_*`` terminal nodes copy the chosen
    answer into ``answer``.
    """
    question: str                  # user query, set before the graph runs
    pdf_answer: Optional[str]      # written by pdf_qa_node
    llm_answer: Optional[str]      # written by llm_fallback_node
    web_answer: Optional[str]      # written by web_search_node
    # Fix: the respond_* nodes return {"answer": ...}, but this key was
    # missing from the schema — LangGraph rejects updates to channels not
    # declared in the state schema (InvalidUpdateError).
    answer: Optional[str]

# --- LangGraph Node Functions ---

def pdf_qa_node(state: GraphState) -> GraphState:
    """Answer the user's question from the PDF vectorstore via RetrievalQA."""
    retriever = vectorstore.as_retriever()
    chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    answer = chain.run(state["question"])
    return {**state, "pdf_answer": answer}

def check_pdf_relevance(state: "GraphState") -> str:
    """Route after the PDF QA step.

    Returns "respond_pdf" when the retrieval answer looks substantive,
    or "llm_fallback" when it is missing, too short, or an explicit
    "I don't know"-style refusal.
    """
    # Fix: `state.get("pdf_answer", "")` returns None — not the default —
    # when the key exists with value None (the schema type is Optional[str]),
    # which crashed on .lower(). `or ""` covers both missing and None.
    ans = (state.get("pdf_answer") or "").lower()
    refusal_markers = (
        "i don't know",
        "i don't have information",
        "no relevant",
        "not available",
    )
    if any(marker in ans for marker in refusal_markers):
        return "llm_fallback"
    # Very short answers are treated as non-answers.
    if len(ans.strip()) < 20:
        return "llm_fallback"
    return "respond_pdf"

def llm_fallback_node(state: GraphState) -> GraphState:
    """Ask the bare LLM directly when PDF retrieval produced nothing useful."""
    question = state["question"]
    prompt = f"""You are a helpful AI assistant. The user asked a question, and no relevant documents were found.



Try your best to answer this:



Question: {question}

Answer:"""
    reply = llm.invoke(prompt)
    return {**state, "llm_answer": reply.content}

def check_llm_confidence(state: "GraphState") -> str:
    """Route after the LLM fallback step.

    Returns "respond_llm" when the model produced a confident answer, or
    "web_search" when the answer is missing/empty or self-reports
    uncertainty.
    """
    # Fix: dict.get's default is not used when the key exists with value
    # None (schema type is Optional[str]); `or ""` avoids None.lower().
    ans = (state.get("llm_answer") or "").lower()
    # An empty/blank answer has nothing to respond with — escalate to web.
    if not ans.strip():
        return "web_search"
    if any(marker in ans for marker in ("i don't know", "not sure", "no idea")):
        return "web_search"
    return "respond_llm"

def web_search_node(state: GraphState) -> GraphState:
    """Last resort: answer the question with a live web search."""
    found = search_tool(state["question"])
    return {**state, "web_answer": found}

def respond_pdf(state: GraphState) -> dict:
    """Terminal node: surface the PDF-derived answer."""
    print("πŸ“„ Responding from PDF")
    final = state["pdf_answer"]
    return {"answer": final}

def respond_llm(state: GraphState) -> dict:
    """Terminal node: surface the direct-LLM answer."""
    print("πŸ€– Responding from LLM")
    final = state["llm_answer"]
    return {"answer": final}

def respond_web(state: GraphState) -> dict:
    """Terminal node: surface the web-search answer."""
    print("🌐 Responding from Web Search")
    final = state["web_answer"]
    return {"answer": final}

# --- Graph Creation Function ---
def create_graph():
    """Build and compile the routing graph: PDF QA -> LLM fallback -> web search."""
    g = StateGraph(GraphState)  # state schema drives the channels

    # Worker nodes.
    g.add_node("pdf_qa", pdf_qa_node)
    g.add_node("llm_fallback", llm_fallback_node)
    g.add_node("web_search", web_search_node)
    # Terminal "respond" nodes.
    g.add_node("respond_pdf", respond_pdf)
    g.add_node("respond_llm", respond_llm)
    g.add_node("respond_web", respond_web)

    g.set_entry_point("pdf_qa")

    # Route on whether the PDF answer looked usable.
    g.add_conditional_edges(
        "pdf_qa",
        check_pdf_relevance,
        {"respond_pdf": "respond_pdf", "llm_fallback": "llm_fallback"},
    )
    # Route on whether the bare LLM sounded confident.
    g.add_conditional_edges(
        "llm_fallback",
        check_llm_confidence,
        {"respond_llm": "respond_llm", "web_search": "web_search"},
    )
    g.add_edge("web_search", "respond_web")

    # Every respond_* node terminates the run.
    for terminal in ("respond_pdf", "respond_llm", "respond_web"):
        g.add_edge(terminal, END)

    return g.compile()