|
""" |
|
This module provides tools for searching and retrieving context from a knowledge base, |
|
and for conducting a research workflow that includes searching, writing, and reviewing reports. |
|
The tools are designed to be used with Modal Labs for scalable and efficient processing. |
|
The technology stack includes FastAPI for the API interface, GroundX for knowledge base search, |
|
LlamaIndex for LLM agent workflows, Nebius for LLM inference, and Modal Labs for tool execution.
|
""" |
|
|
|
import os |
|
import asyncio |
|
|
|
import modal |
|
from pydantic import BaseModel |
|
|
|
# Container image with every runtime dependency the tool endpoints need.
_PACKAGES = (
    "fastapi[standard]",
    "groundx",
    "llama-index",
    "llama-index-llms-nebius",
    "duckduckgo-search",
    "langchain-community",
)

image = modal.Image.debian_slim().pip_install(*_PACKAGES)
|
|
|
# Modal application hosting the FastAPI tool endpoints, built on `image` above.
app = modal.App(name="hackathon-mcp-tools", image=image)
|
|
|
class QueryInput(BaseModel):
    """Request body shared by both POST endpoints: a single free-text query."""

    # The user's search / research query.
    query: str
|
|
|
@app.function(secrets=[
    modal.Secret.from_name("hackathon-secret", required_keys=["GROUNDX_API_KEY"])
])
@modal.fastapi_endpoint(docs=True, method="POST")
def search_rag_context(queryInput: QueryInput) -> str:
    """
    Searches and retrieves relevant context from a knowledge base,
    based on the user's query.

    Args:
        queryInput: Request body carrying the user's search query.

    Returns:
        str: Relevant text content that can be used by the LLM to answer the query.
    """
    result = search_groundx_for_rag_context(queryInput.query)

    # Log the retrieved context for debugging in Modal's function logs.
    print("\n\n=============================")
    print(f"RAG Search Result: {result}")
    print("=============================\n")

    # Bug fix: the original ended with a bare `return`, so the endpoint always
    # responded with None despite the `-> str` annotation.
    return result
|
|
|
def search_groundx_for_rag_context(query) -> str:
    """Query the configured GroundX bucket and return the matched text.

    Falls back to a fixed message when the search yields no text.
    """
    # Imported lazily so the module loads even where `groundx` is absent.
    from groundx import GroundX

    api_key = os.getenv("GROUNDX_API_KEY") or ''
    bucket_id = os.getenv("GROUNDX_BUCKET_ID")

    groundx_client = GroundX(api_key=api_key)
    search_response = groundx_client.search.content(
        id=bucket_id,
        query=query,
        n=10,
    )

    return search_response.search.text or "No relevant context found"
|
|
|
from llama_index.llms.nebius import NebiusLLM |
|
|
|
|
|
from llama_index.core.workflow import Context |
|
from llama_index.core.agent.workflow import ( |
|
FunctionAgent, |
|
AgentWorkflow, |
|
AgentOutput, |
|
ToolCall, |
|
ToolCallResult, |
|
) |
|
|
|
from langchain.utilities import DuckDuckGoSearchAPIWrapper |
|
|
|
@app.function(secrets=[
    modal.Secret.from_name("hackathon-secret", required_keys=["NEBIUS_API_KEY", "AGENT_MODEL"])
])
@modal.fastapi_endpoint(docs=True, method="POST")
def run_research_workflow(queryInput: QueryInput) -> str:
    """
    Run the research -> write -> review agent workflow for the given query.

    Args:
        queryInput: Request body carrying the research topic/query.

    Returns:
        str: The final markdown report produced by the workflow.
    """
    async def _run() -> str:
        # Bug fix: the original called `asyncio.run()` twice. The workflow
        # handler created in the first event loop was then awaited from a
        # brand-new second loop, which can fail or deadlock for loop-bound
        # objects. Keep both steps on a single event loop instead.
        handler = await execute_research_workflow(queryInput.query)
        return await final_report(handler)

    return asyncio.run(_run())
|
|
|
# NOTE(review): these env vars are read at module import time. The Modal
# secret is attached to the decorated functions — confirm its values are
# actually present in the environment when this module is imported in the
# container; if not, NEBIUS_API_KEY will be None here.
NEBIUS_API_KEY = os.getenv("NEBIUS_API_KEY")
AGENT_MODEL = os.getenv("AGENT_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")


# Shared Nebius-hosted LLM used by all three agents below.
llm = NebiusLLM(
    api_key=NEBIUS_API_KEY,
    model=AGENT_MODEL,
    is_function_calling_model=True
)


# Web-search backend used by `duckduckgo_search`.
duckduckgo = DuckDuckGoSearchAPIWrapper()

# Module-level guard state for `duckduckgo_search`: caps the total number of
# searches and remembers past queries. Persists for the life of the process,
# i.e. across requests served by the same container.
MAX_SEARCH_CALLS = 2
search_call_count = 0
past_queries = set()
|
|
|
async def duckduckgo_search(query: str) -> str:
    """
    A DuckDuckGo-based search limiting number of searches and avoiding duplicates.
    """
    global search_call_count, past_queries

    # Refuse repeated queries so the agent cannot loop on the same search.
    if query in past_queries:
        return f"Already searched for '{query}'."

    # Hard cap on total searches for this process.
    if search_call_count >= MAX_SEARCH_CALLS:
        return "Search limit reached."

    past_queries.add(query)
    search_call_count += 1

    return str(duckduckgo.run(query))
|
|
|
|
|
async def save_research(ctx: Context, notes: str, notes_title: str) -> str:
    """
    Store research notes under a given title in the shared context.
    """
    state = await ctx.get("state")
    # Create the notes mapping on first use, then record this entry.
    state.setdefault("research_notes", {})[notes_title] = notes
    await ctx.set("state", state)
    return "Notes saved."
|
|
|
|
|
async def write_report(ctx: Context, report_content: str) -> str:
    """
    Write a report in markdown, storing it in the shared context.
    """
    state = await ctx.get("state")
    state["report_content"] = report_content
    await ctx.set("state", state)
    return "Report written."
|
|
|
|
|
async def review_report(ctx: Context, review: str) -> str:
    """
    Review the report and store feedback in the shared context.
    """
    state = await ctx.get("state")
    state["review"] = review
    await ctx.set("state", state)
    return "Report reviewed."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Agent #1: gathers information from the web and stores notes in shared state.
research_agent = FunctionAgent(
    name="ResearchAgent",
    description=(
        # Bug fix: the description previously claimed "Google search through
        # SerpAPI", but the only search tool wired in is `duckduckgo_search` —
        # a misleading description can confuse the orchestrating LLM.
        "A research agent that searches the web using DuckDuckGo. "
        "It must not exceed 2 searches total, and must avoid repeating the same query. "
        "Once sufficient information is collected, it should hand off to the WriteAgent."
    ),
    system_prompt=(
        "You are the ResearchAgent. Your goal is to gather sufficient information on the topic. "
        "Only perform at most 2 distinct searches. If you have enough information or have reached 2 searches, "
        "handoff to the WriteAgent. Avoid infinite loops! If search throws an error, stop further work and skip WriteAgent and ReviewAgent and return."
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[
        duckduckgo_search,
        save_research
    ],
    # NOTE(review): confirm this FunctionAgent version accepts
    # `max_iterations`/`cooldown`; unknown kwargs may be ignored or rejected.
    max_iterations=2,
    cooldown=5,
    can_handoff_to=["WriteAgent"]
)
|
|
|
# Agent #2: turns the saved research notes into a markdown report.
write_agent = FunctionAgent(
    name="WriteAgent",
    description=(
        "Writes a markdown report based on the research notes. "
        "Then hands off to the ReviewAgent for feedback."
    ),
    system_prompt=(
        # NOTE(review): several adjacent literals concatenate without a
        # separating space (e.g. "...skip ReviewAgent.Do not...") — confirm
        # the resulting prompt reads as intended.
        "You are the WriteAgent. Draft a structured markdown report based on the notes. "
        "If there is no report content or research notes, stop further work and skip ReviewAgent."
        "Do not attempt more than one write attempt. "
        "After writing, hand off to the ReviewAgent."
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[write_report],
    # NOTE(review): confirm FunctionAgent supports `max_iterations`/`cooldown`.
    max_iterations=2,
    cooldown=5,
    # May forward the report for review, or bounce back for more research.
    can_handoff_to=["ReviewAgent", "ResearchAgent"]
)
|
|
|
# Agent #3: reviews the drafted report and either approves it or sends it
# back to the WriteAgent for revisions.
review_agent = FunctionAgent(
    name="ReviewAgent",
    description=(
        "Reviews the final report for correctness. Approves or requests changes."
    ),
    system_prompt=(
        # NOTE(review): adjacent literals concatenate without spaces between
        # sentences — confirm the resulting prompt reads as intended.
        "You are the ReviewAgent. If there is no research notes or report content, skip this step and return."
        "Do not attempt more than one review attempt. "
        "Read the report, provide feedback, and either approve "
        "or request revisions. If revisions are needed, handoff to WriteAgent."
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[review_report],
    # NOTE(review): confirm FunctionAgent supports `max_iterations`/`cooldown`.
    max_iterations=2,
    cooldown=5,
    can_handoff_to=["WriteAgent"]
)
|
|
|
# Orchestrates the three agents; ResearchAgent is the entry point. The
# initial_state keys are the slots the tool functions above read and write.
agent_workflow = AgentWorkflow(
    agents=[research_agent, write_agent, review_agent],
    root_agent=research_agent.name,
    initial_state={
        "research_notes": {},  # filled by save_research
        "report_content": "Not written yet.",  # overwritten by write_report
        "review": "Review required.",  # overwritten by review_report
    },
)
|
|
|
async def execute_research_workflow(query: str):
    """Kick off the research -> write -> review workflow for *query*.

    Streams progress events to stdout as they arrive and returns the
    workflow handler so the caller can read the final state from its context.
    """
    handler = agent_workflow.run(user_msg=query)

    active_agent = None
    _missing = object()  # sentinel: distinguishes "no attribute" from None

    async for event in handler.stream_events():
        # Announce whenever control moves to a different agent.
        agent_name = getattr(event, "current_agent_name", _missing)
        if agent_name is not _missing and agent_name != active_agent:
            active_agent = agent_name
            print(f"\n{'='*50}")
            print(f"🤖 Agent: {active_agent}")
            print(f"{'='*50}\n")

        if isinstance(event, AgentOutput):
            if event.response.content:
                print("📤 Output:", event.response.content)
            if event.tool_calls:
                print("🛠️ Planning to use tools:", [call.tool_name for call in event.tool_calls])
        elif isinstance(event, ToolCall):
            print(f"🔨 Calling Tool: {event.tool_name}")
            print(f"  With arguments: {event.tool_kwargs}")
        elif isinstance(event, ToolCallResult):
            print(f"🔧 Tool Result ({event.tool_name}):")
            print(f"  Arguments: {event.tool_kwargs}")
            print(f"  Output: {event.tool_output}")

    return handler
|
|
|
async def final_report(handler) -> str:
    """Retrieve the final report from the context."""
    state = await handler.ctx.get("state")
    report = state["report_content"]

    # Echo the finished report to the function logs.
    print("\n\n=============================")
    print("FINAL REPORT:\n")
    print(report)
    print("=============================\n")

    return report
|
|