|
import os |
|
from dotenv import load_dotenv |
|
from langgraph.graph import START, StateGraph, MessagesState |
|
from langgraph.prebuilt import tools_condition |
|
from langgraph.prebuilt import ToolNode |
|
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings |
|
from langchain_community.vectorstores import SupabaseVectorStore |
|
from langchain_core.messages import HumanMessage |
|
from langchain.tools.retriever import create_retriever_tool |
|
from supabase.client import Client, create_client |
|
from utils import load_prompt |
|
from tools import calculator, duck_web_search, wiki_search, arxiv_search |
|
|
|
load_dotenv() |
|
|
|
|
|
# Embedding model used to vectorize queries against the Supabase store.
# Must match the model used when the "gaia_documents" table was populated.
embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base")

# Server-side Supabase client; SUPABASE_URL / SUPABASE_SERVICE_KEY come from
# the environment (loaded via load_dotenv above). NOTE(review): os.getenv
# returns None when unset — create_client will then fail at import time.
supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))

# Vector store over the pre-populated "gaia_documents" table; similarity
# search is delegated to the "match_documents_langchain" SQL function.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,  # fixed PEP 8 spacing (was `embedding= embeddings`)
    table_name="gaia_documents",
    query_name="match_documents_langchain",
)
|
|
|
# Retriever tool over the vector store.
# BUG FIX: tool names must match ^[a-zA-Z0-9_-]+$ for function-calling
# providers (e.g. the OpenAI tools schema); the original name
# "ModernBERT Retriever" contained a space and would be rejected if this
# tool were ever bound to the model.
retriever = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="modernbert_retriever",
    description="A retriever of similar questions from a vector store.",
)

# Tools exposed to the agent via bind_tools / ToolNode.
# NOTE(review): `retriever` is deliberately(?) absent here — retrieval is
# done directly in retriever_node instead; confirm this is intended.
tools = [calculator, duck_web_search, wiki_search, arxiv_search]
|
|
|
# Chat model served through the Hugging Face Inference API.
model_id = "Qwen/Qwen3-32B"



# Endpoint handle; temperature=0 for deterministic answers, mild
# repetition penalty. "auto" lets HF pick the serving provider.
llm = HuggingFaceEndpoint(

    repo_id=model_id,

    temperature=0,

    repetition_penalty=1.03,

    provider="auto",

    huggingfacehub_api_token=os.getenv("HF_INFERENCE_KEY")

)

# Wrap the raw endpoint in the chat-message interface expected by LangGraph.
agent = ChatHuggingFace(llm=llm)

# Same model, with the tool schemas attached so it can emit tool calls.
agent_with_tools = agent.bind_tools(tools)
|
|
|
def retriever_node(state: MessagesState):
    """RAG node: fetch a similar solved question from the vector store.

    Uses the first message in the state (assumed to be the user's question
    — TODO confirm) as the similarity-search query and appends the top
    match as a reference message for the downstream agent.
    """
    query = state["messages"][0].content
    matches = vector_store.similarity_search(query)
    # ROBUSTNESS FIX: the original indexed matches[0] unconditionally and
    # raised IndexError when the store returned no results.
    if not matches:
        return {"messages": []}
    response = [HumanMessage(f"Here I provide a similar question and answer for reference: \n\n{matches[0].page_content}")]
    return {"messages": response}
|
|
|
def processor_node(state: MessagesState):
    """Agent node that answers questions.

    Prepends the system prompt (loaded from prompt.yaml on every call) to
    the accumulated messages and invokes the tool-bound model once.
    """
    # FIX: the docstring above was originally a bare string literal placed
    # after the invoke call, so it was never a docstring at all.
    system_prompt = load_prompt("prompt.yaml")
    messages = state.get("messages", [])
    response = [agent_with_tools.invoke([system_prompt] + messages)]
    return {"messages": response}
|
|
|
def agent_graph():
    """Build and compile the RAG agent graph.

    Flow: START -> retriever_node (inject similar Q&A) -> processor_node
    (LLM answer); tools_condition routes to the ToolNode when the model
    emits tool calls, and tool results loop back into processor_node.

    Returns the compiled, runnable graph.
    """
    builder = StateGraph(MessagesState)

    builder.add_node("retriever_node", retriever_node)
    builder.add_node("processor_node", processor_node)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "retriever_node")
    builder.add_edge("retriever_node", "processor_node")
    builder.add_conditional_edges("processor_node", tools_condition)
    builder.add_edge("tools", "processor_node")

    # BUG FIX: the original called builder.compile(), discarded the result,
    # and returned the uncompiled builder — which cannot be invoked.
    return builder.compile()