mcp-rag-workflow / tools /multi_agent_workflow_for_research.py
Rajesh Betkiker
Added tools
0321eee
raw
history blame
8.04 kB
"""
This file contains the multi-agent workflow for the research project.
It uses LlamaIndex to build a modular, intelligent multi-agent workflow
with real-time tools and structured memory.
The workflow is as follows:
1. The ResearchAgent searches the web for information.
2. The WriteAgent writes a report based on the research notes.
3. The ReviewAgent reviews the report and provides feedback.
"""
import os
import asyncio
# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()
from llama_index.llms.nebius import NebiusLLM
# llama-index workflow classes
from llama_index.core.workflow import Context
from llama_index.core.agent.workflow import (
FunctionAgent,
AgentWorkflow,
AgentOutput,
ToolCall,
ToolCallResult,
)
from langchain.utilities import DuckDuckGoSearchAPIWrapper
# API key for the Nebius-hosted LLM. os.getenv returns None when the
# NEBIUS_API_KEY variable is not present in the environment / .env file.
NEBIUS_API_KEY = os.getenv("NEBIUS_API_KEY")
# Load an LLM. This single instance is passed to every agent defined below.
llm = NebiusLLM(
    api_key=NEBIUS_API_KEY,
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    is_function_calling_model=True
)
# Search tools using DuckDuckGo
duckduckgo = DuckDuckGoSearchAPIWrapper()
MAX_SEARCH_CALLS = 2 # Limit the number of searches to 2
# Module-level bookkeeping read and mutated by safe_duckduckgo_search.
search_call_count = 0  # number of searches performed so far
past_queries = set()  # queries already searched; used to skip duplicates
async def safe_duckduckgo_search(query: str) -> str:
    """
    Run a DuckDuckGo search guarded by a call budget and duplicate detection.

    Args:
        query: The search string to look up.

    Returns:
        The search result converted to a string, or an explanatory message
        when the query was already searched or MAX_SEARCH_CALLS searches
        have already been performed. No search is made in those two cases.
    """
    global search_call_count
    # Check for duplicate queries
    if query in past_queries:
        return f"Already searched for '{query}'. Avoiding duplicate search."
    # Check if we've reached the max search calls
    if search_call_count >= MAX_SEARCH_CALLS:
        return "Search limit reached, no more searches allowed."
    # Record the attempt *before* awaiting, so a coroutine scheduled while the
    # search is in flight cannot pass the checks above for the same query or
    # the same budget slot.
    search_call_count += 1
    past_queries.add(query)
    # DuckDuckGoSearchAPIWrapper.run(...) is synchronous; run it in a worker
    # thread so the event loop driving the agent workflow is not blocked.
    result = await asyncio.to_thread(duckduckgo.run, query)
    return str(result)
# Research tools
async def save_research(ctx: Context, notes: str, notes_title: str) -> str:
    """
    Save a piece of research in the workflow's shared state.

    The notes are filed under ``state["research_notes"][notes_title]``; an
    existing entry with the same title is overwritten.

    Args:
        ctx: Workflow context holding the shared "state" dictionary.
        notes: The research text to keep.
        notes_title: Key under which the notes are stored.

    Returns:
        The confirmation string "Notes saved."
    """
    state = await ctx.get("state")
    # Create the notes mapping on first use, then file the new entry in it.
    state.setdefault("research_notes", {})[notes_title] = notes
    await ctx.set("state", state)
    return "Notes saved."
# Report tools
async def write_report(ctx: Context, report_content: str) -> str:
    """
    Record the markdown report in the workflow's shared state.

    Args:
        ctx: Workflow context holding the shared "state" dictionary.
        report_content: The report text; replaces any previous report.

    Returns:
        The confirmation string "Report written."
    """
    state = await ctx.get("state")
    state.update({"report_content": report_content})
    await ctx.set("state", state)
    return "Report written."
# Review tools
async def review_report(ctx: Context, review: str) -> str:
    """
    Record the reviewer's feedback in the workflow's shared state.

    Args:
        ctx: Workflow context holding the shared "state" dictionary.
        review: The feedback text; replaces any previous review.

    Returns:
        The confirmation string "Report reviewed."
    """
    state = await ctx.get("state")
    state.update({"review": review})
    await ctx.set("state", state)
    return "Report reviewed."
# We have three agents with distinct responsibilities:
# - The ResearchAgent is responsible for gathering information from the web.
# - The WriteAgent is responsible for writing the report.
# - The ReviewAgent is responsible for reviewing the report.
# The ResearchAgent uses the DuckDuckGoSearchAPIWrapper to search the web.
research_agent = FunctionAgent(
    name="ResearchAgent",
    # The description and prompt name the search backend actually wired in
    # below (DuckDuckGo) and take the search budget from MAX_SEARCH_CALLS so
    # the text cannot drift from the limit enforced by safe_duckduckgo_search.
    description=(
        "A research agent that searches the web using DuckDuckGo. "
        f"It must not exceed {MAX_SEARCH_CALLS} searches total, and must avoid repeating the same query. "
        "Once sufficient information is collected, it should hand off to the WriteAgent."
    ),
    # Every fragment except the last ends with a space: adjacent string
    # literals are concatenated verbatim.
    system_prompt=(
        "You are the ResearchAgent. Your goal is to gather sufficient information on the topic. "
        f"Only perform at most {MAX_SEARCH_CALLS} distinct searches. "
        f"If you have enough information or have reached {MAX_SEARCH_CALLS} searches, "
        "handoff to the WriteAgent. Avoid infinite loops! If search throws an error, "
        "stop further work and skip WriteAgent and ReviewAgent and return. "
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[
        safe_duckduckgo_search,
        save_research,
    ],
    # NOTE(review): confirm that FunctionAgent honors `max_iterations` and
    # `cooldown`; they may be accepted but ignored by the installed version.
    max_iterations=2,  # Limit to 2 iterations to prevent infinite loops
    cooldown=5,  # Cooldown to prevent rapid re-querying
    can_handoff_to=["WriteAgent"],
)
write_agent = FunctionAgent(
    name="WriteAgent",
    description=(
        "Writes a markdown report based on the research notes. "
        "Then hands off to the ReviewAgent for feedback."
    ),
    # Every fragment except the last ends with a space: adjacent string
    # literals are concatenated verbatim, so a missing space fuses sentences.
    system_prompt=(
        "You are the WriteAgent. Draft a structured markdown report based on the notes. "
        "If there is no report content or research notes, stop further work and skip ReviewAgent. "
        "Do not attempt more than one write attempt. "
        "After writing, hand off to the ReviewAgent. "
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[write_report],
    # NOTE(review): confirm that FunctionAgent honors `max_iterations` and
    # `cooldown`; they may be accepted but ignored by the installed version.
    max_iterations=2,  # Limit to 2 iterations to prevent infinite loops
    cooldown=5,  # Cooldown to prevent rapid re-querying
    can_handoff_to=["ReviewAgent", "ResearchAgent"],
)
review_agent = FunctionAgent(
    name="ReviewAgent",
    description=(
        "Reviews the final report for correctness. Approves or requests changes."
    ),
    # Every fragment except the last ends with a space: adjacent string
    # literals are concatenated verbatim, so a missing space fuses sentences.
    system_prompt=(
        "You are the ReviewAgent. If there is no research notes or report content, skip this step and return. "
        "Do not attempt more than one review attempt. "
        "Read the report, provide feedback, and either approve "
        "or request revisions. If revisions are needed, handoff to WriteAgent. "
        "Respect invocation limits and cooldown periods."
    ),
    llm=llm,
    tools=[review_report],
    # NOTE(review): confirm that FunctionAgent honors `max_iterations` and
    # `cooldown`; they may be accepted but ignored by the installed version.
    max_iterations=2,  # Limit to 2 iterations to prevent infinite loops
    cooldown=5,  # Cooldown to prevent rapid re-querying
    can_handoff_to=["WriteAgent"],
)
# Wire the three agents into one workflow that shares a single state dict.
agent_workflow = AgentWorkflow(
    agents=[research_agent, write_agent, review_agent],
    root_agent=research_agent.name, # Start with the ResearchAgent
    initial_state={
        # Filled in by save_research, keyed by note title.
        "research_notes": {},
        # Placeholder until write_report stores the real report.
        "report_content": "Not written yet.",
        # Placeholder until review_report stores the reviewer's feedback.
        "review": "Review required.",
    },
)
async def execute_research_workflow(query: str):
    """
    Run the multi-agent workflow for one query, printing progress as it goes.

    Each streamed event is echoed to stdout: a banner whenever the active
    agent changes, agent outputs and planned tool calls, and every tool call
    together with its result.

    Args:
        query: The user message handed to the workflow's root agent.

    Returns:
        The handler returned by ``agent_workflow.run`` once its event stream
        is exhausted; pass it to ``final_report`` to read the report.
    """
    global search_call_count
    # The search budget lives in module globals. Start every run with a fresh
    # budget and query history; otherwise only the first run in a process
    # could search and later runs would get "Search limit reached".
    search_call_count = 0
    past_queries.clear()

    handler = agent_workflow.run(user_msg=query)
    current_agent = None
    async for event in handler.stream_events():
        # Not every event type carries an agent name, hence the hasattr guard.
        if hasattr(event, "current_agent_name") and event.current_agent_name != current_agent:
            current_agent = event.current_agent_name
            print(f"\n{'='*50}")
            print(f"🤖 Agent: {current_agent}")
            print(f"{'='*50}\n")
        # Print outputs or tool calls
        if isinstance(event, AgentOutput):
            if event.response.content:
                print("📤 Output:", event.response.content)
            if event.tool_calls:
                print("🛠️ Planning to use tools:", [call.tool_name for call in event.tool_calls])
        elif isinstance(event, ToolCall):
            print(f"🔨 Calling Tool: {event.tool_name}")
            print(f" With arguments: {event.tool_kwargs}")
        elif isinstance(event, ToolCallResult):
            print(f"🔧 Tool Result ({event.tool_name}):")
            print(f" Arguments: {event.tool_kwargs}")
            print(f" Output: {event.tool_output}")
    return handler
async def final_report(handler) -> str:
    """
    Print and return the report stored in a finished workflow's state.

    Args:
        handler: The workflow handler whose ``ctx`` holds the shared "state"
            dictionary (as returned by ``execute_research_workflow``).

    Returns:
        The value stored under ``state["report_content"]``.
    """
    state = await handler.ctx.get("state")
    report = state["report_content"]
    # One print call with a newline separator emits the same text as four
    # consecutive print calls.
    print(
        "\n\n=============================",
        "FINAL REPORT:\n",
        report,
        "=============================\n",
        sep="\n",
    )
    return report
def run_research_workflow(query: str):
    """
    Synchronous entry point: run the research workflow and return the report.

    Args:
        query: The research question passed to the workflow.

    Returns:
        The final report text produced by ``final_report``.
    """
    async def _run_and_collect():
        # Run the workflow and read its context on the SAME event loop.
        # Two separate asyncio.run() calls would close the loop the handler
        # was created on before its context is awaited on a new one.
        handler = await execute_research_workflow(query)
        return await final_report(handler)

    return asyncio.run(_run_and_collect())