betki commited on
Commit
184f36a
Β·
verified Β·
1 Parent(s): d8571a8

Added new tools

Browse files
Files changed (1) hide show
  1. tools/tools_on_modal_labs.py +280 -0
tools/tools_on_modal_labs.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module provides tools for searching and retrieving context from a knowledge base,
3
+ and for conducting a research workflow that includes searching, writing, and reviewing reports.
4
+ The tools are designed to be used with Modal Labs for scalable and efficient processing.
5
+ The technology stack includes FastAPI for the API interface, GroundX for knowledge base search,
6
+ LlamaIndex for LLM workflows, Nebius for LLM, and Modal Labs for tool execution.
7
+ """
8
+
9
+ import os
10
+ import asyncio
11
+
12
+ import modal
13
+ from pydantic import BaseModel
14
+
15
+ image = modal.Image.debian_slim().pip_install(
16
+ "fastapi[standard]",
17
+ "groundx",
18
+ "llama-index",
19
+ "llama-index-llms-nebius",
20
+ "duckduckgo-search",
21
+ "langchain-community")
22
+
23
+ app = modal.App(name="hackathon-mcp-tools", image=image)
24
+
25
+ class QueryInput(BaseModel):
26
+ query: str
27
+
28
+ @app.function(secrets=[
29
+ modal.Secret.from_name("hackathon-secret", required_keys=["GROUNDX_API_KEY"])
30
+ ])
31
+ @modal.fastapi_endpoint(docs=True, method="POST")
32
+ def search_rag_context(queryInput: QueryInput) -> str:
33
+ """
34
+ Searches and retrieves relevant context from a knowledge base,
35
+ based on the user's query.
36
+ Args:
37
+ query: The search query supplied by the user.
38
+ Returns:
39
+ str: Relevant text content that can be used by the LLM to answer the query.
40
+ """
41
+
42
+ result = search_groundx_for_rag_context(queryInput.query)
43
+
44
+ print("\n\n=============================")
45
+ print(f"RAG Search Result: {result}")
46
+ print("=============================\n")
47
+
48
+ return
49
+
50
+ def search_groundx_for_rag_context(query) -> str:
51
+ from groundx import GroundX
52
+
53
+ client = GroundX(api_key=os.getenv("GROUNDX_API_KEY") or '')
54
+ response = client.search.content(
55
+ id=os.getenv("GROUNDX_BUCKET_ID"),
56
+ query=query,
57
+ n=10,
58
+ )
59
+
60
+ return response.search.text or "No relevant context found"
61
+
62
+ from llama_index.llms.nebius import NebiusLLM
63
+
64
+ # llama-index workflow classes
65
+ from llama_index.core.workflow import Context
66
+ from llama_index.core.agent.workflow import (
67
+ FunctionAgent,
68
+ AgentWorkflow,
69
+ AgentOutput,
70
+ ToolCall,
71
+ ToolCallResult,
72
+ )
73
+
74
+ from langchain.utilities import DuckDuckGoSearchAPIWrapper
75
+
76
+ @app.function(secrets=[
77
+ modal.Secret.from_name("hackathon-secret", required_keys=["NEBIUS_API_KEY", "AGENT_MODEL"])
78
+ ])
79
+ @modal.fastapi_endpoint(docs=True, method="POST")
80
+ def run_research_workflow(queryInput: QueryInput) -> str:
81
+ handler = asyncio.run(execute_research_workflow(queryInput.query))
82
+ result = asyncio.run(final_report(handler))
83
+ return result
84
+
85
+ NEBIUS_API_KEY = os.getenv("NEBIUS_API_KEY")
86
+ AGENT_MODEL = os.getenv("AGENT_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
87
+
88
+ # Load an LLM
89
+ llm = NebiusLLM(
90
+ api_key=NEBIUS_API_KEY,
91
+ model=AGENT_MODEL,
92
+ is_function_calling_model=True
93
+ )
94
+
95
+ # Search tools using DuckDuckGo
96
+ duckduckgo = DuckDuckGoSearchAPIWrapper()
97
+
98
+ MAX_SEARCH_CALLS = 2 # Limit the number of searches to 2
99
+ search_call_count = 0
100
+ past_queries = set()
101
+
102
+ async def duckduckgo_search(query: str) -> str:
103
+ """
104
+ A DuckDuckGo-based search limiting number of searches and avoiding duplicates.
105
+ """
106
+ global search_call_count, past_queries
107
+
108
+ # Check for duplicate queries
109
+ if query in past_queries:
110
+ return f"Already searched for '{query}'."
111
+
112
+ # Check if we've reached the max search calls
113
+ if search_call_count >= MAX_SEARCH_CALLS:
114
+ return "Search limit reached."
115
+
116
+ # Otherwise, perform the search
117
+ search_call_count += 1
118
+ past_queries.add(query)
119
+
120
+ result = duckduckgo.run(query)
121
+ return str(result)
122
+
123
+ # Research tools
124
+ async def save_research(ctx: Context, notes: str, notes_title: str) -> str:
125
+ """
126
+ Store research notes under a given title in the shared context.
127
+ """
128
+
129
+ current_state = await ctx.get("state")
130
+ if "research_notes" not in current_state:
131
+ current_state["research_notes"] = {}
132
+ current_state["research_notes"][notes_title] = notes
133
+ await ctx.set("state", current_state)
134
+ return "Notes saved."
135
+
136
+ # Report tools
137
+ async def write_report(ctx: Context, report_content: str) -> str:
138
+ """
139
+ Write a report in markdown, storing it in the shared context.
140
+ """
141
+
142
+ current_state = await ctx.get("state")
143
+ current_state["report_content"] = report_content
144
+ await ctx.set("state", current_state)
145
+ return "Report written."
146
+
147
+ # Review tools
148
+ async def review_report(ctx: Context, review: str) -> str:
149
+ """
150
+ Review the report and store feedback in the shared context.
151
+ """
152
+
153
+ current_state = await ctx.get("state")
154
+ current_state["review"] = review
155
+ await ctx.set("state", current_state)
156
+ return "Report reviewed."
157
+
158
+
159
+ # We have three agents with distinct responsibilities:
160
+ # - The ResearchAgent is responsible for gathering information from the web.
161
+ # - The WriteAgent is responsible for writing the report.
162
+ # - The ReviewAgent is responsible for reviewing the report.
163
+
164
+ # The ResearchAgent uses the DuckDuckGoSearchAPIWrapper to search the web.
165
+
166
+ research_agent = FunctionAgent(
167
+ name="ResearchAgent",
168
+ description=(
169
+ "A research agent that searches the web using Google search through SerpAPI. "
170
+ "It must not exceed 2 searches total, and must avoid repeating the same query. "
171
+ "Once sufficient information is collected, it should hand off to the WriteAgent."
172
+ ),
173
+ system_prompt=(
174
+ "You are the ResearchAgent. Your goal is to gather sufficient information on the topic. "
175
+ "Only perform at most 2 distinct searches. If you have enough information or have reached 2 searches, "
176
+ "handoff to the WriteAgent. Avoid infinite loops! If search throws an error, stop further work and skip WriteAgent and ReviewAgent and return."
177
+ "Respect invocation limits and cooldown periods."
178
+ ),
179
+ llm=llm,
180
+ tools=[
181
+ duckduckgo_search,
182
+ save_research
183
+ ],
184
+ max_iterations=2, # Limit to 2 iterations to prevent infinite loops
185
+ cooldown=5, # Cooldown to prevent rapid re-querying
186
+ can_handoff_to=["WriteAgent"]
187
+ )
188
+
189
+ write_agent = FunctionAgent(
190
+ name="WriteAgent",
191
+ description=(
192
+ "Writes a markdown report based on the research notes. "
193
+ "Then hands off to the ReviewAgent for feedback."
194
+ ),
195
+ system_prompt=(
196
+ "You are the WriteAgent. Draft a structured markdown report based on the notes. "
197
+ "If there is no report content or research notes, stop further work and skip ReviewAgent."
198
+ "Do not attempt more than one write attempt. "
199
+ "After writing, hand off to the ReviewAgent."
200
+ "Respect invocation limits and cooldown periods."
201
+ ),
202
+ llm=llm,
203
+ tools=[write_report],
204
+ max_iterations=2, # Limit to 2 iterations to prevent infinite loops
205
+ cooldown=5, # Cooldown to prevent rapid re-querying
206
+ can_handoff_to=["ReviewAgent", "ResearchAgent"]
207
+ )
208
+
209
+ review_agent = FunctionAgent(
210
+ name="ReviewAgent",
211
+ description=(
212
+ "Reviews the final report for correctness. Approves or requests changes."
213
+ ),
214
+ system_prompt=(
215
+ "You are the ReviewAgent. If there is no research notes or report content, skip this step and return."
216
+ "Do not attempt more than one review attempt. "
217
+ "Read the report, provide feedback, and either approve "
218
+ "or request revisions. If revisions are needed, handoff to WriteAgent."
219
+ "Respect invocation limits and cooldown periods."
220
+ ),
221
+ llm=llm,
222
+ tools=[review_report],
223
+ max_iterations=2, # Limit to 2 iterations to prevent infinite loops
224
+ cooldown=5, # Cooldown to prevent rapid re-querying
225
+ can_handoff_to=["WriteAgent"]
226
+ )
227
+
228
+ agent_workflow = AgentWorkflow(
229
+ agents=[research_agent, write_agent, review_agent],
230
+ root_agent=research_agent.name, # Start with the ResearchAgent
231
+ initial_state={
232
+ "research_notes": {},
233
+ "report_content": "Not written yet.",
234
+ "review": "Review required.",
235
+ },
236
+ )
237
+
238
+ async def execute_research_workflow(query: str):
239
+ handler = agent_workflow.run(
240
+ user_msg=(
241
+ query
242
+ )
243
+ )
244
+
245
+ current_agent = None
246
+
247
+ async for event in handler.stream_events():
248
+ if hasattr(event, "current_agent_name") and event.current_agent_name != current_agent:
249
+ current_agent = event.current_agent_name
250
+ print(f"\n{'='*50}")
251
+ print(f"πŸ€– Agent: {current_agent}")
252
+ print(f"{'='*50}\n")
253
+
254
+ # Print outputs or tool calls
255
+ if isinstance(event, AgentOutput):
256
+ if event.response.content:
257
+ print("πŸ“€ Output:", event.response.content)
258
+ if event.tool_calls:
259
+ print("πŸ› οΈ Planning to use tools:", [call.tool_name for call in event.tool_calls])
260
+
261
+ elif isinstance(event, ToolCall):
262
+ print(f"πŸ”¨ Calling Tool: {event.tool_name}")
263
+ print(f" With arguments: {event.tool_kwargs}")
264
+
265
+ elif isinstance(event, ToolCallResult):
266
+ print(f"πŸ”§ Tool Result ({event.tool_name}):")
267
+ print(f" Arguments: {event.tool_kwargs}")
268
+ print(f" Output: {event.tool_output}")
269
+
270
+ return handler
271
+
272
+ async def final_report(handler) -> str:
273
+ """Retrieve the final report from the context."""
274
+ final_state = await handler.ctx.get("state")
275
+ print("\n\n=============================")
276
+ print("FINAL REPORT:\n")
277
+ print(final_state["report_content"])
278
+ print("=============================\n")
279
+
280
+ return final_state["report_content"]