"""
Tools for searching and retrieving context from a knowledge base, and for running a
research workflow that searches the web, writes a report, and reviews it.

The tools are designed to run on Modal Labs for scalable, efficient processing. The stack:
FastAPI for the API interface, GroundX for knowledge-base search, LlamaIndex for the LLM
agent workflow, Nebius for LLM inference, Tavily for web search, and Modal Labs for tool
execution.
"""
import os
import asyncio

import modal
from pydantic import BaseModel
image = modal.Image.debian_slim().pip_install(
    "fastapi[standard]",
    "groundx",
    "llama-index",
    "llama-index-llms-nebius",
    "tavily-python",
)

app = modal.App(name="hackathon-mcp-tools", image=image)
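
# During development the endpoints below can be tried out with `modal serve`, and
# published with `modal deploy`; Modal prints the generated *.modal.run URL for each
# @modal.fastapi_endpoint when it starts the app.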
class QueryInput(BaseModel):
    query: str
@app.function(secrets=[
    modal.Secret.from_name("hackathon-secret", required_keys=["GROUNDX_API_KEY"])
])
@modal.fastapi_endpoint(docs=True, method="POST")
def search_rag_context(queryInput: QueryInput) -> str:
    """
    Search the knowledge base and return context relevant to the user's query.

    Args:
        queryInput: Request body containing the search query supplied by the user.

    Returns:
        str: Relevant text content that the LLM can use to answer the query.
    """
    result = search_groundx_for_rag_context(queryInput.query)

    print("\n\n=============================")
    print(f"RAG Search Result: {result}")
    print("=============================\n")

    return result
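
# Example request once the endpoint is served or deployed (the URL is a placeholder for
# the one Modal prints; adjust it to your workspace):
#
#   curl -X POST "https://<your-search-rag-context-url>.modal.run" \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does the knowledge base say about X?"}'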
def search_groundx_for_rag_context(query: str) -> str:
    from groundx import GroundX

    client = GroundX(api_key=os.getenv("GROUNDX_API_KEY") or '')
    response = client.search.content(
        id=os.getenv("GROUNDX_BUCKET_ID"),
        query=query,
        n=10,
    )

    return response.search.text or "No relevant context found"
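
# Note: GROUNDX_BUCKET_ID must also be available in the container environment (e.g. via the
# "hackathon-secret" Modal secret); it identifies the GroundX bucket whose documents are
# searched, and response.search.text is the retrieved context GroundX assembles for the query.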
from llama_index.llms.nebius import NebiusLLM

from llama_index.core.workflow import Context
from llama_index.core.agent.workflow import (
    FunctionAgent,
    AgentWorkflow,
    AgentOutput,
    ToolCall,
    ToolCallResult,
)

from tavily import AsyncTavilyClient
@app.function(secrets=[
    modal.Secret.from_name("hackathon-secret", required_keys=["NEBIUS_API_KEY", "AGENT_MODEL"])
])
@modal.fastapi_endpoint(docs=True, method="POST")
def run_research_workflow(queryInput: QueryInput) -> str:
    # Run the workflow and read the final report on a single event loop; two separate
    # asyncio.run() calls would access the handler's context from a second loop.
    async def _run() -> str:
        handler = await execute_research_workflow(queryInput.query)
        return await final_report(handler)

    return asyncio.run(_run())
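
# This endpoint blocks until the agent workflow finishes and responds with the markdown
# report kept in the shared state (see final_report below). It can be exercised with the
# same kind of POST request shown for search_rag_context above.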
NEBIUS_API_KEY = os.getenv("NEBIUS_API_KEY")
AGENT_MODEL = os.getenv("AGENT_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct")
print(f"Using LLM model: {AGENT_MODEL}")

llm = NebiusLLM(
    api_key=NEBIUS_API_KEY,
    model=AGENT_MODEL,
    is_function_calling_model=True
)

TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
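
# Note: these values are read at module import time, so NEBIUS_API_KEY, AGENT_MODEL and
# TAVILY_API_KEY need to be present in the container environment (e.g. included in the
# "hackathon-secret" Modal secret attached to run_research_workflow) for the agents and
# web search to work.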
async def search_web(query: str) -> str:
    """Useful for using the web to answer questions."""
    client = AsyncTavilyClient(api_key=TAVILY_API_KEY)
    return str(await client.search(query))
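
# Tavily's search() returns a dict of results, which is handed to the LLM as one raw
# string here; a leaner variant could keep only the "results" entries if the stringified
# payload becomes too large for the prompt.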
async def save_research(ctx: Context, notes: str, notes_title: str) -> str:
    """
    Store research notes under a given title in the shared context.
    """
    current_state = await ctx.get("state")
    if "research_notes" not in current_state:
        current_state["research_notes"] = {}
    current_state["research_notes"][notes_title] = notes
    await ctx.set("state", current_state)
    return "Notes saved."


async def write_report(ctx: Context, report_content: str) -> str:
    """
    Write a report in markdown, storing it in the shared context.
    """
    current_state = await ctx.get("state")
    current_state["report_content"] = report_content
    await ctx.set("state", current_state)
    return "Report written."
async def review_report(ctx: Context, review: str) -> str:
    """
    Review the report and store feedback in the shared context.
    """
    current_state = await ctx.get("state")
    current_state["review"] = review
    await ctx.set("state", current_state)
    return "Report reviewed."
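
# After a full run, the "state" dict kept in the workflow Context looks roughly like
# (illustrative values only):
#
#   {
#       "research_notes": {"<notes_title>": "<notes text>"},
#       "report_content": "<markdown report>",
#       "review": "<reviewer feedback>",
#   }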
research_agent = FunctionAgent(
    name="ResearchAgent",
    description=(
        "Useful for searching the web for information on a given topic and recording notes on the topic. "
        "It must not exceed 3 searches total and must avoid repeating the same query. "
        "Once sufficient information is collected, it should hand off to the WriteAgent."
    ),
    system_prompt=(
        "You are the ResearchAgent. Your goal is to search the web for information on a given topic and record notes on the topic. "
        "Only perform at most 3 distinct searches. If you have enough information or have reached 3 searches, "
        "hand off to the WriteAgent. Avoid infinite loops! If the search throws an error, stop further work, skip the WriteAgent and ReviewAgent, and return. "
        "You should have at least some notes on a topic before handing off control to the WriteAgent. "
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[
        search_web,
        save_research,
    ],
    max_iterations=3,
    cooldown=5,
    can_handoff_to=["WriteAgent"],
)
write_agent = FunctionAgent(
    name="WriteAgent",
    description=(
        "Writes a markdown report based on the research notes. "
        "Then hands off to the ReviewAgent for feedback."
    ),
    system_prompt=(
        "You are the WriteAgent. Draft a structured markdown report based on the notes. "
        "If there is no report content or research notes, stop further work and skip the ReviewAgent. "
        "Do not attempt more than one write attempt. "
        "After writing, hand off to the ReviewAgent. "
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[write_report],
    max_iterations=2,
    cooldown=5,
    can_handoff_to=["ReviewAgent", "ResearchAgent"],
)
review_agent = FunctionAgent(
    name="ReviewAgent",
    description=(
        "Reviews the final report for correctness. Approves or requests changes."
    ),
    system_prompt=(
        "You are the ReviewAgent. If there are no research notes or report content, skip this step and return. "
        "Do not attempt more than one review attempt. "
        "Read the report, provide feedback, and either approve "
        "or request revisions. If revisions are needed, hand off to the WriteAgent. "
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[review_report],
    max_iterations=2,
    cooldown=5,
    can_handoff_to=["WriteAgent"],
)
agent_workflow = AgentWorkflow(
    agents=[research_agent, write_agent, review_agent],
    root_agent=research_agent.name,
    initial_state={
        "research_notes": {},
        "report_content": "No report has been generated after the search.",
        "review": "Review required.",
    },
)
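
# Handoff chain as configured above:
#   ResearchAgent -> WriteAgent -> ReviewAgent (which may hand back to WriteAgent for revisions)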
async def execute_research_workflow(query: str):
    handler = agent_workflow.run(user_msg=query)

    current_agent = None

    async for event in handler.stream_events():
        if hasattr(event, "current_agent_name") and event.current_agent_name != current_agent:
            current_agent = event.current_agent_name
            print(f"\n{'='*50}")
            print(f"🤖 Agent: {current_agent}")
            print(f"{'='*50}\n")

        if isinstance(event, AgentOutput):
            if event.response.content:
                print("📤 Output:", event.response.content)
            if event.tool_calls:
                print("🛠️ Planning to use tools:", [call.tool_name for call in event.tool_calls])

        elif isinstance(event, ToolCall):
            print(f"🔨 Calling Tool: {event.tool_name}")
            print(f"  With arguments: {event.tool_kwargs}")

        elif isinstance(event, ToolCallResult):
            print(f"🔧 Tool Result ({event.tool_name}):")
            print(f"  Arguments: {event.tool_kwargs}")
            print(f"  Output: {event.tool_output}")

    return handler
async def final_report(handler) -> str:
    """Retrieve the final report from the workflow context."""
    final_state = await handler.ctx.get("state")
    print("\n\n=============================")
    print("FINAL REPORT:\n")
    print(final_state["report_content"])
    print("=============================\n")

    return final_state["report_content"]
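
# Minimal local smoke test (illustrative; assumes the required API keys are set in the
# local environment and uses an example query):
#
#   if __name__ == "__main__":
#       async def _demo():
#           handler = await execute_research_workflow("Recent developments in RAG evaluation")
#           return await final_report(handler)
#
#       print(asyncio.run(_demo()))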