"""LangGraph Agent"""
import os
import pandas as pd
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
from pydantic import BaseModel, Field
from typing import List, Set, Any
load_dotenv()
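# Required environment variables (e.g. via a .env file): SUPABASE_URL,
# SUPABASE_SERVICE_KEY, HF_TOKEN, and TAVILY_API_KEY. GOOGLE_API_KEY or
# GROQ_API_KEY are needed only when those providers are selected.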
class TableCommutativityInput(BaseModel):
table: List[List[Any]] = Field(description="The 2D list representing the multiplication table.")
elements: List[str] = Field(description="The list of header elements corresponding to the table rows/columns.")
class VegetableListInput(BaseModel):
items: List[str] = Field(description="A list of grocery item strings.")
@tool
def multiply(a: int, b: int) -> int:
"""Multiply two numbers.
Args:
a: first int
b: second int
"""
return a * b
@tool
def add(a: int, b: int) -> int:
"""Add two numbers.
Args:
a: first int
b: second int
"""
return a + b
@tool
def subtract(a: int, b: int) -> int:
"""Subtract two numbers.
Args:
a: first int
b: second int
"""
return a - b
@tool
def divide(a: int, b: int) -> float:
"""Divide two numbers.
Args:
a: first int
b: second int
"""
if b == 0:
raise ValueError("Cannot divide by zero.")
return a / b
@tool
def modulus(a: int, b: int) -> int:
"""Get the modulus of two numbers.
Args:
a: first int
b: second int
"""
return a % b
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return a maximum of 2 results.
    Args:
        query: The search query."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ])
    return {"wiki_results": formatted_search_docs}
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return a maximum of 3 results.
    Each result is formatted with its source URL and content.
    Args:
        query: The search query.
    """
    print("\n--- Web Search Tool ---")  # For debugging
    print(f"Received query: {query}")
    try:
        # TavilySearchResults.invoke() takes the query via 'input' and
        # typically returns a list of dicts with "url", "content", "title",
        # and "score" keys.
        tavily_tool = TavilySearchResults(max_results=3)
        search_results_list = tavily_tool.invoke(input=query)
        if isinstance(search_results_list, list):
            formatted_docs = []
            for doc_dict in search_results_list:
                if isinstance(doc_dict, dict):
                    source = doc_dict.get("url", "N/A")
                    content = doc_dict.get("content", "")
                    # Wrap each result in the XML-like format the prompt expects.
                    formatted_docs.append(
                        f'<Document source="{source}">\n{content}\n</Document>'
                    )
                else:
                    # Fall back to a plain string for unexpected item types.
                    print(f"Warning: Unexpected item type in Tavily results list: {type(doc_dict)}")
                    formatted_docs.append(str(doc_dict))
            final_formatted_string = "\n\n---\n\n".join(formatted_docs)
        elif isinstance(search_results_list, str):  # Less common, but handled for robustness
            final_formatted_string = search_results_list
        else:
            print(f"Unexpected Tavily search result format: {type(search_results_list)}")
            final_formatted_string = str(search_results_list)  # Fallback
        print(f"Formatted search docs for LLM:\n{final_formatted_string[:500]}...")  # Print a snippet
        return {"web_results": final_formatted_string}
    except Exception as e:
        print(f"Error during Tavily search for query '{query}': {e}")
        # Return the error in the expected dict format so the agent can recover.
        return {"web_results": f"Error performing web search: {e}"}
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return a maximum of 3 results.
    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    return {"arxiv_results": formatted_search_docs}
@tool
def reverse_text(text_to_reverse: str) -> str:
"""Reverses the input text.
Args:
text_to_reverse: The text to be reversed.
"""
if not isinstance(text_to_reverse, str):
raise TypeError("Input must be a string.")
return text_to_reverse[::-1]
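# Example: reverse_text(".rewsna") returns "answer." -- handy when a prompt
# arrives with its text reversed.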
@tool(args_schema=TableCommutativityInput)
def find_non_commutative_elements(table: List[List[Any]], elements: List[str]) -> str:
"""
Given a multiplication table (2D list) and its header elements,
returns a comma-separated string of elements involved in any non-commutative operations (a*b != b*a),
sorted alphabetically.
"""
if len(table) != len(elements) or (len(table) > 0 and len(table[0]) != len(elements)):
raise ValueError("Table dimensions must match the number of elements.")
non_comm: Set[str] = set()
for i, a in enumerate(elements):
for j, b in enumerate(elements):
if i < j: # Avoid checking twice (a*b vs b*a and b*a vs a*b) and self-comparison
if table[i][j] != table[j][i]:
non_comm.add(a)
non_comm.add(b)
# Return as a comma-separated string as per typical LLM tool output preference
return ", ".join(sorted(list(non_comm)))
@tool(args_schema=VegetableListInput)
def list_vegetables(items: List[str]) -> str:
    """
    From a list of grocery items, returns a comma-separated string of those
    that are true vegetables (botanical definition, based on a predefined set),
    sorted alphabetically.
    """
    # Botanically, a vegetable is an edible root, stem, leaf, or flower bud.
    # Corn, bell peppers, green beans, and zucchini develop from flowers and
    # contain seeds, so they are botanical fruits and deliberately excluded.
    _VEG_SET = {
        "broccoli",        # flower buds and stem
        "celery",          # stem
        "lettuce",         # leaf
        "sweet potatoes",  # root
    }
    vegetables_found = sorted([item for item in items if item.lower() in _VEG_SET])
    return ", ".join(vegetables_found)
class ExcelSumFoodInput(BaseModel):
excel_path: str = Field(description="The file path to the .xlsx Excel file to read.")
@tool(args_schema=ExcelSumFoodInput)
def sum_food_sales(excel_path: str) -> str:
"""
Reads an Excel file with columns 'Category' and 'Sales',
and returns total sales (as a string) for categories that are NOT 'Drink',
rounded to two decimal places.
Args:
excel_path: The file path to the .xlsx Excel file to read.
"""
try:
df = pd.read_excel(excel_path)
if "Category" not in df.columns or "Sales" not in df.columns:
raise ValueError("Excel file must contain 'Category' and 'Sales' columns.")
# Ensure 'Sales' column is numeric, coercing errors to NaN
df["Sales"] = pd.to_numeric(df["Sales"], errors='coerce')
# Filter out 'Drink' and then sum, handling potential NaNs from coercion
total = df.loc[df["Category"].str.lower() != "drink", "Sales"].sum(skipna=True)
return str(round(float(total), 2))
except FileNotFoundError:
return f"Error: File not found at path '{excel_path}'"
except ValueError as ve:
return f"Error processing Excel file: {ve}"
except Exception as e:
return f"An unexpected error occurred: {e}"
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
system_prompt = f.read()
# System message
sys_msg = SystemMessage(content=system_prompt)
# build a retriever
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
supabase: Client = create_client(
os.environ.get("SUPABASE_URL"),
os.environ.get("SUPABASE_SERVICE_KEY"))
vector_store = SupabaseVectorStore(
client=supabase,
    embedding=embeddings,
table_name="documents",
query_name="match_documents_langchain",
)
retriever_tool = create_retriever_tool(
retriever=vector_store.as_retriever(),
name="Question Search",
description="A tool to retrieve similar questions from a vector store.",
)
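# Note: retriever_tool is not added to the tools list below; the retriever
# node queries the vector store directly. Append it to `tools` if the LLM
# should be able to call it on its own.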
tools = [
multiply,
add,
subtract,
divide,
modulus,
wiki_search,
web_search,
    arxiv_search,
reverse_text,
find_non_commutative_elements,
list_vegetables,
sum_food_sales,
]
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
raise ValueError("Hugging Face API token (HF_TOKEN) not found in environment variables.")
tavily_key = os.environ.get('TAVILY_API_KEY')
if not tavily_key:
    raise ValueError("Tavily API key (TAVILY_API_KEY) not found in environment variables.")
# Build graph function
def build_graph(provider: str = "huggingface"):
"""Build the graph"""
    # Select the chat model for the requested provider.
if provider == "google":
# Google Gemini
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
elif provider == "groq":
# Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # alternative: gemma2-9b-it
elif provider == "huggingface":
# repo_id = "togethercomputer/evo-1-131k-base"
# repo_id="HuggingFaceH4/zephyr-7b-beta",
# repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
if not hf_token:
raise ValueError("HF_TOKEN environment variable not set. It's required for Hugging Face provider.")
llm = HuggingFaceEndpoint(
repo_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
provider="auto",
task="text-generation",
max_new_tokens=1000,
do_sample=False,
repetition_penalty=1.03,
)
llm = ChatHuggingFace(llm=llm)
else:
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind tools to the LLM so it can emit structured tool calls
llm_with_tools = llm.bind_tools(tools)
# Node
    def assistant(state: MessagesState):
        """Assistant node: invoke the tool-bound LLM on the accumulated messages."""
        print("\n--- Assistant Node ---")
        print("Incoming messages to assistant:")
        for msg in state["messages"]:
            msg.pretty_print()
        return {"messages": [llm_with_tools.invoke(state["messages"])]}
    def retriever(state: MessagesState):
        """Retriever node: prepend the system prompt and a similar solved question."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        if not similar_question:
            # No similar question found; fall back to the system prompt alone.
            return {"messages": [sys_msg] + state["messages"]}
        example_msg = HumanMessage(
            content=f"Here is a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        print("Messages with example:", [sys_msg] + state["messages"] + [example_msg])
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}
builder = StateGraph(MessagesState)
builder.add_node("retriever", retriever)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_edge("retriever", "assistant")
builder.add_conditional_edges(
"assistant",
tools_condition,
)
builder.add_edge("tools", "assistant")
    # Compile graph
    compiled_graph = builder.compile()
    # --- Optional: save a visualization of the compiled graph ---
try:
print("Attempting to generate graph visualization...")
image_filename = "langgraph_state_diagram.png"
# Using draw_mermaid_png as it's often more robust
image_bytes = compiled_graph.get_graph().draw_mermaid_png()
with open(image_filename, "wb") as f:
f.write(image_bytes)
print(f"SUCCESS: Graph visualization saved to '{image_filename}'")
except ImportError as e:
print(f"WARNING: Could not generate graph image due to missing package: {e}. "
"Ensure 'pygraphviz' and 'graphviz' (system) are installed, or Mermaid components are available.")
    except Exception as e:
        print(f"WARNING: An error occurred while generating the graph image: {e}")
        try:
            # Fall back to printing the Mermaid source, which needs no renderer.
            print("\nGraph (Mermaid source as fallback):\n", compiled_graph.get_graph().draw_mermaid())
        except Exception as mermaid_e:
            print(f"Could not render the Mermaid source either: {mermaid_e}")
    return compiled_graph
# test
if __name__ == "__main__":
question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
# Build the graph
graph = build_graph(provider="huggingface")
# Run the graph
messages = [HumanMessage(content=question)]
print(messages)
config = {"recursion_limit": 27}
messages = graph.invoke({"messages": messages}, config=config)
for m in messages["messages"]:
m.pretty_print()