Spaces:
Sleeping
Sleeping
from langgraph.graph import StateGraph, END | |
from langchain.chains import RetrievalQA | |
from typing import TypedDict, Optional | |
from tools import llm, load_vectorstore, search_tool | |
# Build the vector store once at import time; pdf_qa_node() retrieves from it.
vectorstore = load_vectorstore()
# --- TypedDict to define graph state schema ---
class GraphState(TypedDict):
    """State shared by every node of the question-answering graph.

    ``question`` is the user's query. Each ``*_answer`` field is filled in by
    the node that produces it. ``answer`` is the final reply written by the
    ``respond_*`` nodes.
    """

    question: str
    pdf_answer: Optional[str]  # written by pdf_qa_node
    llm_answer: Optional[str]  # written by llm_fallback_node
    web_answer: Optional[str]  # written by web_search_node
    # The respond_* nodes return {"answer": ...}. The key must be declared
    # here: LangGraph applies node updates only to keys that belong to the
    # state schema, so an undeclared "answer" would not reach the caller.
    answer: Optional[str]
# --- LangGraph Node Functions ---
def pdf_qa_node(state: GraphState) -> GraphState:
    """Answer the question from the indexed documents.

    Builds a RetrievalQA chain over the module-level ``vectorstore``
    retriever and stores the chain's answer under ``pdf_answer``.

    Args:
        state: Current graph state; ``state["question"]`` is the query.

    Returns:
        A copy of ``state`` with ``pdf_answer`` set to the chain's answer.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vectorstore.as_retriever(),
    )
    # Chain.run() is deprecated in LangChain; invoke() is the supported entry
    # point. RetrievalQA takes its input under "query" and returns the answer
    # under "result".
    output = qa_chain.invoke({"query": state["question"]})
    return {**state, "pdf_answer": output["result"]}
def check_pdf_relevance(state: GraphState) -> str:
    """Route after ``pdf_qa``: decide whether the document answer is usable.

    Args:
        state: Graph state; ``pdf_answer`` may be missing or ``None``.

    Returns:
        ``"llm_fallback"`` when the answer is absent, shorter than 20
        characters after stripping, or contains a known "no answer" phrase;
        ``"respond_pdf"`` otherwise.
    """
    # ``state.get(key, "")`` still yields None when the key is present with a
    # None value (the field is Optional), so coalesce explicitly.
    answer = (state.get("pdf_answer") or "").lower()
    # Fold the typographic apostrophe so "I don’t know" matches the
    # plain-ASCII phrases below.
    answer = answer.replace("\u2019", "'")

    no_answer_phrases = (
        "i don't know",
        "i don't have information",
        "no relevant",
        "not available",
    )
    too_short = len(answer.strip()) < 20
    if too_short or any(phrase in answer for phrase in no_answer_phrases):
        return "llm_fallback"
    return "respond_pdf"
def llm_fallback_node(state: GraphState) -> GraphState:
    """Ask the LLM directly when the documents gave no usable answer.

    Args:
        state: Current graph state; ``state["question"]`` is the query.

    Returns:
        A copy of ``state`` with ``llm_answer`` set to the ``content``
        attribute of the model's reply.
    """
    question = state["question"]
    prompt_lines = (
        "You are a helpful AI assistant. The user asked a question, and no relevant documents were found.",
        "Try your best to answer this:",
        f"Question: {question}",
        "Answer:",
    )
    reply = llm.invoke("\n".join(prompt_lines))
    return {**state, "llm_answer": reply.content}
def check_llm_confidence(state: GraphState) -> str:
    """Route after ``llm_fallback``: decide whether the LLM answer is usable.

    Args:
        state: Graph state; ``llm_answer`` may be missing or ``None``.

    Returns:
        ``"web_search"`` when the answer is absent, blank, or contains a
        phrase signalling uncertainty; ``"respond_llm"`` otherwise.
    """
    # ``state.get(key, "")`` still yields None when the key is present with a
    # None value (the field is Optional), so coalesce explicitly.
    answer = (state.get("llm_answer") or "").lower()
    # Fold the typographic apostrophe so "I don’t know" matches the
    # plain-ASCII phrase below.
    answer = answer.replace("\u2019", "'")

    # With no text at all there is nothing to respond with, so keep falling
    # back instead of returning an empty answer.
    if not answer.strip():
        return "web_search"

    uncertain_phrases = ("i don't know", "not sure", "no idea")
    if any(phrase in answer for phrase in uncertain_phrases):
        return "web_search"
    return "respond_llm"
def web_search_node(state: GraphState) -> GraphState:
    """Run the web search tool on the question.

    Args:
        state: Current graph state; ``state["question"]`` is the query.

    Returns:
        A copy of ``state`` with ``web_answer`` set to whatever
        ``search_tool`` returned for the question.
    """
    updated = dict(state)
    updated["web_answer"] = search_tool(state["question"])
    return updated
def respond_pdf(state: GraphState) -> dict:
    """Emit the document-based answer as the final answer.

    Args:
        state: Graph state; ``state["pdf_answer"]`` must be present.

    Returns:
        A dict with the single key ``"answer"`` holding ``pdf_answer``.
    """
    # Console trace of which branch produced the reply.
    print("📄 Responding from PDF")
    return {"answer": state["pdf_answer"]}
def respond_llm(state: GraphState) -> dict:
    """Emit the LLM's own answer as the final answer.

    Args:
        state: Graph state; ``state["llm_answer"]`` must be present.

    Returns:
        A dict with the single key ``"answer"`` holding ``llm_answer``.
    """
    # Console trace of which branch produced the reply.
    print("🤖 Responding from LLM")
    return {"answer": state["llm_answer"]}
def respond_web(state: GraphState) -> dict:
    """Emit the web-search result as the final answer.

    Args:
        state: Graph state; ``state["web_answer"]`` must be present.

    Returns:
        A dict with the single key ``"answer"`` holding ``web_answer``.
    """
    # Console trace of which branch produced the reply.
    print("🌐 Responding from Web Search")
    return {"answer": state["web_answer"]}
# --- Graph Creation Function ---
def create_graph():
    """Assemble and compile the PDF -> LLM -> web-search fallback graph.

    Flow: ``pdf_qa`` is the entry point. ``check_pdf_relevance`` routes to
    ``respond_pdf`` or ``llm_fallback``; ``check_llm_confidence`` then routes
    to ``respond_llm`` or ``web_search``; ``web_search`` always continues to
    ``respond_web``. Every ``respond_*`` node leads to ``END``.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    graph = StateGraph(GraphState)  # state schema

    # Register nodes in a fixed order: workers first, then responders.
    node_handlers = {
        "pdf_qa": pdf_qa_node,
        "llm_fallback": llm_fallback_node,
        "web_search": web_search_node,
        "respond_pdf": respond_pdf,
        "respond_llm": respond_llm,
        "respond_web": respond_web,
    }
    for node_name, handler in node_handlers.items():
        graph.add_node(node_name, handler)

    graph.set_entry_point("pdf_qa")

    # Router return values double as node names, so each mapping is identity.
    graph.add_conditional_edges(
        "pdf_qa",
        check_pdf_relevance,
        {route: route for route in ("respond_pdf", "llm_fallback")},
    )
    graph.add_conditional_edges(
        "llm_fallback",
        check_llm_confidence,
        {route: route for route in ("respond_llm", "web_search")},
    )
    graph.add_edge("web_search", "respond_web")

    # Every responder terminates the run.
    for terminal in ("respond_pdf", "respond_llm", "respond_web"):
        graph.add_edge(terminal, END)

    return graph.compile()