# LangGraph-FastAPI / langgraph_chain.py
# Uploaded by Vishaltiwari2019 ("Upload 5 files", commit 9bf33d9, 3.23 kB).
from langgraph.graph import StateGraph, END
from langchain.chains import RetrievalQA
from typing import TypedDict, Optional
from tools import llm, load_vectorstore, search_tool
# Load the vector store once, at import time. pdf_qa_node builds a retriever
# from this module-level object on every call.
# NOTE(review): what load_vectorstore() actually loads (index path, embedding
# model) is defined in `tools` and not visible here.
vectorstore = load_vectorstore()
# --- TypedDict to define graph state schema ---
class GraphState(TypedDict):
    """State shared by every node of the PDF -> LLM -> web-search graph.

    Each worker node reads ``question`` and fills in the answer slot for its
    stage; the ``respond_*`` terminal nodes copy the chosen stage's text into
    ``answer``.
    """

    # The user's question. No node writes it, so it must come from the input.
    question: str
    # Text produced by the RetrievalQA chain over the vector store.
    pdf_answer: Optional[str]
    # Text produced by the plain-LLM fallback.
    llm_answer: Optional[str]
    # Whatever the web-search tool returned.
    web_answer: Optional[str]
    # Final answer written by the respond_* nodes. It must be declared here:
    # LangGraph only keeps state keys that appear in the schema, so without
    # it the value those nodes return never reaches the graph's output.
    answer: Optional[str]
# --- LangGraph Node Functions ---
def pdf_qa_node(state: GraphState) -> GraphState:
    """Answer ``state["question"]`` from the module-level vector store.

    Builds a RetrievalQA chain from the shared ``llm`` and a retriever over
    ``vectorstore``, then runs it on the question.

    Returns:
        A copy of *state* with ``pdf_answer`` set to the chain's answer.
    """
    query = state["question"]
    qa = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
    # ``Chain.run`` is deprecated; ``invoke`` takes and returns dicts keyed by
    # the chain's own input/output key names, so read them from the chain.
    output = qa.invoke({qa.input_key: query})
    return {**state, "pdf_answer": output[qa.output_key]}
# Lower-case phrases signalling that the RetrievalQA chain found no answer.
_PDF_NO_ANSWER_MARKERS = (
    "i don't know",
    "i don't have information",
    "no relevant",
    "not available",
)
# Stripped answers shorter than this many characters are treated as unusable.
_PDF_MIN_ANSWER_LENGTH = 20


def check_pdf_relevance(state: GraphState) -> str:
    """Route after ``pdf_qa``: keep the PDF answer or fall back to the LLM.

    The answer is considered unusable when it is missing or None, shorter
    than ``_PDF_MIN_ANSWER_LENGTH`` characters after stripping, or contains
    one of ``_PDF_NO_ANSWER_MARKERS`` (case-insensitive).

    Returns:
        ``"llm_fallback"`` for an unusable answer, otherwise ``"respond_pdf"``.
    """
    # ``or ""`` covers both a missing key and one explicitly set to None.
    # The typographic apostrophe (U+2019) is folded to ASCII so an LLM that
    # writes "don't" with a curly quote still matches the markers.
    ans = (state.get("pdf_answer") or "").lower().replace("\u2019", "'")
    if len(ans.strip()) < _PDF_MIN_ANSWER_LENGTH:
        return "llm_fallback"
    if any(marker in ans for marker in _PDF_NO_ANSWER_MARKERS):
        return "llm_fallback"
    return "respond_pdf"
def llm_fallback_node(state: GraphState) -> GraphState:
    """Ask the bare LLM the question after the PDF answer was judged unusable.

    Returns:
        A copy of *state* with ``llm_answer`` set to the model's reply text.
    """
    query = state["question"]
    # Same text as a flush-left triple-quoted string, spelled as implicit
    # concatenation so code indentation can never leak into the prompt.
    prompt = (
        "You are a helpful AI assistant. The user asked a question, and no relevant documents were found.\n"
        "Try your best to answer this:\n"
        f"Question: {query}\n"
        "Answer:"
    )
    res = llm.invoke(prompt)
    # Chat models return a message object exposing ``.content``; plain
    # (non-chat) LLMs return the completion string itself.
    return {**state, "llm_answer": getattr(res, "content", res)}
# Lower-case phrases signalling that the fallback LLM did not really answer.
_LLM_UNSURE_MARKERS = ("i don't know", "not sure", "no idea")


def check_llm_confidence(state: GraphState) -> str:
    """Route after ``llm_fallback``: accept the LLM reply or search the web.

    The reply is considered unreliable when it is missing, None, blank, or
    contains one of ``_LLM_UNSURE_MARKERS`` (case-insensitive).

    Returns:
        ``"web_search"`` for an unreliable reply, otherwise ``"respond_llm"``.
    """
    # ``or ""`` covers both a missing key and one explicitly set to None; the
    # typographic apostrophe (U+2019) is folded to ASCII before matching.
    ans = (state.get("llm_answer") or "").lower().replace("\u2019", "'")
    # An empty reply is not a usable answer, so try the web instead.
    if not ans.strip():
        return "web_search"
    if any(marker in ans for marker in _LLM_UNSURE_MARKERS):
        return "web_search"
    return "respond_llm"
def web_search_node(state: GraphState) -> GraphState:
    """Last-resort stage: run the web-search tool on the user's question.

    Returns:
        A new dict equal to *state* plus ``web_answer`` holding whatever
        ``search_tool`` returned.
    """
    # NOTE(review): search_tool is imported from `tools`; its return type
    # (plain text vs. structured results) is not visible here.
    found = search_tool(state["question"])
    updated = dict(state)
    updated["web_answer"] = found
    return updated
def respond_pdf(state: GraphState) -> dict:
    """Terminal node: publish the PDF-derived answer as the final answer.

    Returns:
        ``{"answer": state["pdf_answer"]}``.
    """
    # Named escape for U+1F4C4 (page facing up) so the emoji cannot be
    # re-encoded into mojibake by a wrong source-file encoding.
    print("\N{PAGE FACING UP} Responding from PDF")
    # NOTE(review): LangGraph keeps only keys declared in the state schema;
    # confirm GraphState declares ``answer`` so this value reaches the output.
    return {"answer": state["pdf_answer"]}
def respond_llm(state: GraphState) -> dict:
    """Terminal node: publish the LLM-fallback answer as the final answer.

    Returns:
        ``{"answer": state["llm_answer"]}``.
    """
    # Named escape for U+1F916 (robot face) so the emoji cannot be
    # re-encoded into mojibake by a wrong source-file encoding.
    print("\N{ROBOT FACE} Responding from LLM")
    # NOTE(review): LangGraph keeps only keys declared in the state schema;
    # confirm GraphState declares ``answer`` so this value reaches the output.
    return {"answer": state["llm_answer"]}
def respond_web(state: GraphState) -> dict:
    """Terminal node: publish the web-search result as the final answer.

    Returns:
        ``{"answer": state["web_answer"]}``.
    """
    # Named escape for U+1F310 (globe with meridians) so the emoji cannot be
    # re-encoded into mojibake by a wrong source-file encoding.
    print("\N{GLOBE WITH MERIDIANS} Responding from Web Search")
    # NOTE(review): LangGraph keeps only keys declared in the state schema;
    # confirm GraphState declares ``answer`` so this value reaches the output.
    return {"answer": state["web_answer"]}
# --- Graph Creation Function ---
def create_graph():
    """Build and compile the PDF -> LLM -> web-search question-answering graph.

    Routing:
      * ``pdf_qa`` is the entry point; ``check_pdf_relevance`` picks
        ``respond_pdf`` or ``llm_fallback`` next.
      * ``check_llm_confidence`` picks ``respond_llm`` or ``web_search``
        after ``llm_fallback``.
      * ``web_search`` always continues to ``respond_web``.
      * Each ``respond_*`` node ends the run.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    graph = StateGraph(GraphState)

    # Worker nodes first, then the terminal responders.
    node_table = {
        "pdf_qa": pdf_qa_node,
        "llm_fallback": llm_fallback_node,
        "web_search": web_search_node,
        "respond_pdf": respond_pdf,
        "respond_llm": respond_llm,
        "respond_web": respond_web,
    }
    for node_name, node_fn in node_table.items():
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("pdf_qa")

    # Each router returns a node name directly, so its path map is an
    # identity mapping over the allowed targets.
    routers = (
        ("pdf_qa", check_pdf_relevance, ("respond_pdf", "llm_fallback")),
        ("llm_fallback", check_llm_confidence, ("respond_llm", "web_search")),
    )
    for source, router, targets in routers:
        graph.add_conditional_edges(source, router, {target: target for target in targets})

    graph.add_edge("web_search", "respond_web")

    # Every responder is a terminal node.
    for terminal in ("respond_pdf", "respond_llm", "respond_web"):
        graph.add_edge(terminal, END)

    return graph.compile()