wishwakankanamg commited on
Commit
0e07fa6
·
1 Parent(s): 3b6de84
__pycache__/tools_agent.cpython-310.pyc CHANGED
Binary files a/__pycache__/tools_agent.cpython-310.pyc and b/__pycache__/tools_agent.cpython-310.pyc differ
 
agent.py DELETED
@@ -1,436 +0,0 @@
1
- """LangGraph Agent"""
2
- import os
3
- import pandas as pd
4
- from dotenv import load_dotenv
5
- from langgraph.graph import START, StateGraph, MessagesState
6
- from langgraph.prebuilt import tools_condition
7
- from langgraph.prebuilt import ToolNode
8
- from langchain_google_genai import ChatGoogleGenerativeAI
9
- from langchain_groq import ChatGroq
10
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings, HuggingFacePipeline
11
- from langchain_community.tools.tavily_search import TavilySearchResults
12
- from langchain_community.document_loaders import WikipediaLoader
13
- from langchain_community.document_loaders import ArxivLoader
14
- from langchain_community.vectorstores import SupabaseVectorStore
15
- from langchain_core.messages import SystemMessage, HumanMessage
16
- from langchain_core.tools import tool
17
- from langchain.tools.retriever import create_retriever_tool
18
- from supabase.client import Client, create_client
19
- from pydantic import BaseModel, Field
20
-
21
-
22
- from typing import List, Set, Any
23
-
24
- load_dotenv()
25
-
26
class TableCommutativityInput(BaseModel):
    # Argument schema for the find_non_commutative_elements tool.

    # Square operation table; row i / column j holds elements[i] * elements[j].
    table: List[List[Any]] = Field(
        description="The 2D list representing the multiplication table.",
    )
    # Labels for the rows and columns of the table, in table order.
    elements: List[str] = Field(
        description="The list of header elements corresponding to the table rows/columns.",
    )
29
-
30
class VegetableListInput(BaseModel):
    # Argument schema for the list_vegetables tool.

    # Raw grocery item names to be filtered.
    items: List[str] = Field(
        description="A list of grocery item strings.",
    )
32
-
33
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers and return their product.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
41
-
42
@tool
def add(a: int, b: int) -> int:
    """Add two numbers and return their sum.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
51
-
52
@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second number from the first.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
61
-
62
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int (the dividend)
        b: second int (the divisor); must not be zero

    Returns:
        The true (floating-point) quotient a / b.

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division always produces a float, hence the float return annotation.
    return a / b
73
-
74
@tool
def modulus(a: int, b: int) -> int:
    """Get the remainder of dividing the first number by the second.

    Args:
        a: first int
        b: second int
    """
    # No explicit zero check: a zero divisor propagates Python's ZeroDivisionError.
    remainder = a % b
    return remainder
83
-
84
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "wiki_results" whose value is one string
        containing every loaded document, separated by "---" lines.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Each document is wrapped in a <Document> header carrying its source and
    # page (when present) so the model can cite where the text came from.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
97
-
98
- # @tool
99
- # def web_search(query: str) -> str:
100
- # """Search Tavily for a query and return maximum 3 results.
101
-
102
- # Args:
103
- # query: The search query."""
104
- # search_docs = TavilySearchResults(max_results=3).invoke(query=query)
105
- # formatted_search_docs = "\n\n---\n\n".join(
106
- # [
107
- # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
108
- # for doc in search_docs
109
- # ])
110
- # return {"web_results": formatted_search_docs}
111
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.
    Each result will be formatted with its source URL and content.

    Args:
        query: The search query.
    """
    # Debug trace of the incoming request.
    print("\n--- Web Search Tool ---")
    print(f"Received query: {query}")
    try:
        # Tavily is expected to hand back a list of result dictionaries.
        raw_results = TavilySearchResults(max_results=3).invoke(input=query)
        print(f"Raw Tavily search results type: {type(raw_results)}")

        if isinstance(raw_results, list):
            print(f"Number of results: {len(raw_results)}")
            if raw_results:
                first_result = raw_results[0]
                print(f"Type of first result: {type(first_result)}")
                if isinstance(first_result, dict):
                    print(f"Keys in first result: {first_result.keys()}")

            rendered = []
            for entry in raw_results:
                if not isinstance(entry, dict):
                    # Anything that is not a dict is passed through as plain text.
                    print(f"Warning: Unexpected item type in Tavily results list: {type(entry)}")
                    rendered.append(str(entry))
                    continue
                url = entry.get("url", "N/A")
                body = entry.get("content", "")
                rendered.append(f'<Document source="{url}">\n{body}\n</Document>')
            llm_text = "\n\n---\n\n".join(rendered)
        elif isinstance(raw_results, str):
            # Less common, but a bare string is forwarded unchanged.
            llm_text = raw_results
        else:
            print(f"Unexpected Tavily search result format overall: {type(raw_results)}")
            llm_text = str(raw_results)

        # Only a snippet is echoed to keep the log readable.
        print(f"Formatted search docs for LLM:\n{llm_text[:500]}...")
        return {"web_results": llm_text}
    except Exception as exc:
        # Best effort: report the failure in the same dict shape as a success.
        print(f"Error during Tavily search for query '{query}': {exc}")
        return {"web_results": f"Error performing web search: {exc}"}
171
-
172
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "arvix_results" whose value is one string
        containing the first 1000 characters of every loaded paper, separated
        by "---" lines.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()

    def _source_of(metadata: dict) -> str:
        # Arxiv documents are not guaranteed to carry a "source" entry, and a
        # hard lookup would raise KeyError; fall back to the entry id, then the
        # paper title, then an empty string.
        return metadata.get("source") or metadata.get("entry_id") or metadata.get("Title", "")

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{_source_of(doc.metadata)}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
185
-
186
@tool
def reverse_text(text_to_reverse: str) -> str:
    """Return the input text with its characters in reverse order.

    Args:
        text_to_reverse: The text to be reversed.
    """
    if isinstance(text_to_reverse, str):
        return "".join(reversed(text_to_reverse))
    raise TypeError("Input must be a string.")
195
-
196
-
197
@tool(args_schema=TableCommutativityInput)
def find_non_commutative_elements(table: List[List[Any]], elements: List[str]) -> str:
    """
    Given a multiplication table (2D list) and its header elements,
    returns a comma-separated string of elements involved in any non-commutative operations (a*b != b*a),
    sorted alphabetically.

    Args:
        table: Square 2D list; table[i][j] is the result of elements[i] * elements[j].
        elements: Header labels for the rows/columns, in table order.

    Returns:
        The offending elements joined by ", " (an empty string when the
        operation is commutative).

    Raises:
        ValueError: if the table is not len(elements) x len(elements).
    """
    size = len(elements)
    # Check every row, not just the first, so a ragged table is rejected up
    # front instead of surfacing later as an IndexError.
    if len(table) != size or any(len(row) != size for row in table):
        raise ValueError("Table dimensions must match the number of elements.")

    non_comm: Set[str] = set()
    for i in range(size):
        # Only the upper triangle is visited: each unordered pair is compared
        # once and the diagonal (a*a) is skipped.
        for j in range(i + 1, size):
            if table[i][j] != table[j][i]:
                non_comm.update((elements[i], elements[j]))
    return ", ".join(sorted(non_comm))
216
-
217
-
218
@tool(args_schema=VegetableListInput)
def list_vegetables(items: List[str]) -> str:
    """
    From a list of grocery items, returns a comma-separated string of those
    that are treated as vegetables (membership in a predefined set),
    sorted alphabetically.

    Args:
        items: Grocery item names; matching ignores case and surrounding whitespace.

    Returns:
        The matching items (whitespace-trimmed) joined by ", ".
    """
    # NOTE(review): this hand-curated set is kept exactly as provided. Several
    # entries (corn, bell pepper, green beans, zucchini) are botanically fruits,
    # so the set does NOT implement a strict botanical definition; revise it if
    # that is the goal.
    _VEG_SET = {
        "broccoli", "bell pepper", "celery", "corn",
        "green beans", "lettuce", "sweet potatoes", "zucchini",
    }

    trimmed = (item.strip() for item in items)
    vegetables_found = sorted(name for name in trimmed if name.lower() in _VEG_SET)
    return ", ".join(vegetables_found)
245
-
246
class ExcelSumFoodInput(BaseModel):
    # Argument schema for the sum_food_sales tool.

    # Location of the workbook on the local filesystem.
    excel_path: str = Field(
        description="The file path to the .xlsx Excel file to read.",
    )
248
-
249
@tool(args_schema=ExcelSumFoodInput)
def sum_food_sales(excel_path: str) -> str:
    """
    Reads an Excel file with columns 'Category' and 'Sales',
    and returns total sales (as a string) for categories that are NOT 'Drink',
    rounded to two decimal places.
    Args:
        excel_path: The file path to the .xlsx Excel file to read.
    """
    try:
        sheet = pd.read_excel(excel_path)
        has_required_columns = "Category" in sheet.columns and "Sales" in sheet.columns
        if not has_required_columns:
            raise ValueError("Excel file must contain 'Category' and 'Sales' columns.")

        # Non-numeric sales values become NaN and are ignored by the sum below.
        sheet["Sales"] = pd.to_numeric(sheet["Sales"], errors="coerce")

        # Keep every row whose category is not "drink" (case-insensitive).
        is_food = sheet["Category"].str.lower() != "drink"
        food_sales = sheet.loc[is_food, "Sales"]
        total = food_sales.sum(skipna=True)

        return str(round(float(total), 2))
    except FileNotFoundError:
        return f"Error: File not found at path '{excel_path}'"
    except ValueError as value_error:
        return f"Error processing Excel file: {value_error}"
    except Exception as error:
        # Errors are reported as text so the calling agent can read them.
        return f"An unexpected error occurred: {error}"
276
-
277
# Load the system prompt from the file.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message built from the prompt file; used by the retriever node.
sys_msg = SystemMessage(content=system_prompt)

# Build a retriever backed by a Supabase vector store.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# Bound to its own name so the imported create_retriever_tool factory is not
# overwritten by its result.
# NOTE(review): this tool is not included in the `tools` list below — confirm
# whether it should be exposed to the model.
question_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
300
-
301
-
302
-
303
# Tools bound to the LLM and executed by the graph's tool node.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
    reverse_text,
    find_non_commutative_elements,
    list_vegetables,
    sum_food_sales,
]

# Fail fast at import time when a required credential is missing.
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token (HF_TOKEN) not found in environment variables.")

tavili_key = os.environ.get('TAVILY_API_KEY')
if not tavili_key:
    # The message names the variable that is actually missing.
    raise ValueError("Tavily API key (TAVILY_API_KEY) not found in environment variables.")
325
-
326
-
327
- # Build graph function
328
def build_graph(provider: str = "huggingface"):
    """Build and compile the agent graph.

    Args:
        provider: Which chat model backend to use: "google", "groq" or
            "huggingface".

    Returns:
        The compiled graph, ready for ``invoke``.

    Raises:
        ValueError: if ``provider`` is not one of the supported names, or if
            the Hugging Face provider is selected without HF_TOKEN set.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable not set. It's required for Hugging Face provider.")
        endpoint = HuggingFaceEndpoint(
            repo_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            provider="auto",
            task="text-generation",
            max_new_tokens=1000,
            do_sample=False,
            repetition_penalty=1.03,
        )
        llm = ChatHuggingFace(llm=endpoint)
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node"""
        print("\n--- Assistant Node ---")
        print("Incoming messages to assistant:")
        for msg in state["messages"]:
            msg.pretty_print()
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node"""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        if not similar_question:
            # Nothing similar stored: still prepend the system prompt, but
            # skip the reference example instead of failing on index 0.
            return {"messages": [sys_msg] + state["messages"]}
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        augmented = [sys_msg] + state["messages"] + [example_msg]
        # Passed as a separate argument: concatenating a str with a list raises TypeError.
        print("ex msgs", augmented)
        return {"messages": augmented}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    # NOTE(review): the graph starts at "assistant" and no edge leads into
    # "retriever", so the retriever node (and with it sys_msg) is never used.
    # Topology is left as found — confirm whether START should point at
    # "retriever" instead.
    builder.add_edge(START, "assistant")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    compiled_graph = builder.compile()

    # Best-effort visualization: a failure here must never prevent the graph
    # from being returned.
    try:
        print("Attempting to generate graph visualization...")
        image_filename = "langgraph_state_diagram.png"
        image_bytes = compiled_graph.get_graph().draw_mermaid_png()
        with open(image_filename, "wb") as f:
            f.write(image_bytes)
        print(f"SUCCESS: Graph visualization saved to '{image_filename}'")
    except ImportError as e:
        print(f"WARNING: Could not generate graph image due to missing package: {e}. "
              "Ensure 'pygraphviz' and 'graphviz' (system) are installed, or Mermaid components are available.")
    except Exception as e:
        print(f"WARNING: An error occurred while generating the graph image: {e}")
        try:
            print("\nGraph (DOT format as fallback):\n", compiled_graph.get_graph().to_string())
        except Exception as dot_e:
            print(f"Could not even get DOT string: {dot_e}")

    return compiled_graph
420
-
421
- # test
422
if __name__ == "__main__":
    # Manual smoke test: run one sample question through the agent.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph(provider="huggingface")

    initial_messages = [HumanMessage(content=question)]
    print(initial_messages)

    # Cap the number of graph steps for this run.
    run_config = {"recursion_limit": 27}
    final_state = graph.invoke({"messages": initial_messages}, config=run_config)

    for message in final_state["messages"]:
        message.pretty_print()
435
-
436
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  gradio
2
  requests
3
- smolagents==1.13.0
4
  pandas
5
  smolagents[openai]
 
1
  gradio
2
  requests
3
+ smolagents
4
  pandas
5
  smolagents[openai]