|
import logging |
|
import os |
|
import uuid |
|
from typing import AsyncIterator, Optional, TypedDict |
|
|
|
from dotenv import find_dotenv, load_dotenv |
|
from langgraph.checkpoint.memory import MemorySaver |
|
from langgraph.graph import END, START, StateGraph |
|
from smolagents import CodeAgent, LiteLLMModel |
|
from smolagents.memory import ActionStep, FinalAnswerStep |
|
from smolagents.monitoring import LogLevel |
|
|
|
|
|
# Configure process-wide logging once, at import time.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


# Load settings from the nearest .env file (no-op if none is found).
load_dotenv(find_dotenv())


# Connection settings for the LiteLLM-backed model; all three are required.
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

# Fail fast and name exactly which settings are absent, instead of the
# original blanket message that forced the user to check all three.
_missing = [
    name
    for name, value in (
        ("API_BASE", API_BASE),
        ("API_KEY", API_KEY),
        ("MODEL_ID", MODEL_ID),
    )
    if not value
]
if _missing:
    raise ValueError(
        f"Missing required environment variables: {', '.join(_missing)}"
    )
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph nodes for one agent run."""

    # The user-provided task the agent is solving.
    task: str
    # Serialized view of the most recent ActionStep, or None when no step
    # has been recorded yet for this iteration.
    current_step: Optional[dict]
    # Error message captured from a failed step, or None on success.
    error: Optional[str]
    # Final answer text once the agent finishes, else None.
    answer_text: Optional[str]
|
|
|
|
|
|
|
# Instantiate the LiteLLM-backed chat model from the validated environment
# settings. This runs at import time so misconfiguration surfaces immediately.
try:
    model = LiteLLMModel(
        api_base=API_BASE,
        api_key=API_KEY,
        model_id=MODEL_ID,
    )
except Exception:
    # logger.exception records the full traceback; the original
    # logger.error(f"... {e}") discarded it.
    logger.exception("Failed to initialize model")
    raise
|
|
|
|
|
# Build the CodeAgent that executes tasks. Base tools are enabled and
# pandas/numpy are whitelisted for the agent's generated code.
try:
    agent = CodeAgent(
        add_base_tools=True,
        additional_authorized_imports=["pandas", "numpy"],
        max_steps=10,
        model=model,
        tools=[],
        step_callbacks=None,
        verbosity_level=LogLevel.ERROR,
    )
    # Narrow the agent's rich console so its output wraps consistently
    # (66 columns; presumably chosen to fit the hosting UI — confirm).
    agent.logger.console.width = 66
except Exception:
    # logger.exception keeps the traceback; logger.error(f"...") lost it.
    logger.exception("Failed to initialize agent")
    raise
|
|
|
|
|
async def process_step(state: AgentState) -> AgentState:
    """Run the agent for exactly one step and record the outcome in the state.

    Resets the per-step fields, streams a single agent step
    (``max_steps=1, reset=False`` keeps the agent's own memory across calls),
    and stores either a serialized ActionStep, the final answer text, or an
    error message back into the state.
    """
    try:
        # Clear per-iteration fields so stale data never leaks between steps.
        state["current_step"] = None
        state["answer_text"] = None
        state["error"] = None

        event_stream = agent.run(
            task=state["task"],
            additional_args=None,
            images=None,
            max_steps=1,
            stream=True,
            reset=False,
        )

        for event in event_stream:
            if isinstance(event, ActionStep):
                # Serialize the step into plain dicts so it survives
                # checkpointing and can be printed downstream.
                calls = [
                    {"name": call.name, "arguments": call.arguments}
                    for call in (event.tool_calls or [])
                ]
                state["current_step"] = {
                    "step_number": event.step_number,
                    "model_output": event.model_output,
                    "observations": event.observations,
                    "tool_calls": calls,
                    "action_output": event.action_output,
                }
                logger.info(f"Processed action step {event.step_number}")
            elif isinstance(event, FinalAnswerStep):
                state["answer_text"] = event.final_answer
                logger.info("Processed final answer")
                logger.debug(f"Final answer details: {event}")
                logger.info(f"Extracted answer text: {state['answer_text']}")
                # Stop consuming the stream once the final answer arrives.
                return state

        return state
    except Exception as e:
        # Record the failure in the state; should_continue will end the run.
        state["error"] = str(e)
        logger.error(f"Error during agent execution step: {str(e)}")
        return state
|
|
|
|
|
def should_continue(state: AgentState) -> bool:
    """Decide whether the graph should schedule another agent step.

    Execution continues only while there is neither a final answer nor a
    recorded error in the state.
    """
    has_answer = state.get("answer_text") is not None
    has_error = state.get("error") is not None
    keep_going = not has_answer and not has_error
    logger.debug(
        f"Checking should_continue: answer_text={has_answer}, error={has_error} -> Continue={keep_going}"
    )
    return keep_going
|
|
|
|
|
|
|
# Checkpointer: keeps per-thread graph state in memory so each thread_id
# accumulates its run history across astream invocations.
memory = MemorySaver()

# Single-node loop: "process_step" executes one agent step, then
# should_continue routes back into the same node until an answer or an
# error appears in the state.
builder = StateGraph(AgentState)
builder.add_node("process_step", process_step)
builder.add_edge(START, "process_step")
builder.add_conditional_edges(
    "process_step", should_continue, {True: "process_step", False: END}
)
graph = builder.compile(checkpointer=memory)
|
|
|
|
|
async def stream_execution(task: str, thread_id: str) -> AsyncIterator[dict]:
    """Stream the agent's execution through the graph.

    Yields the raw chunks produced by ``graph.astream``: each chunk maps a
    node name (here always ``"process_step"``) to that node's state update.

    Args:
        task: Natural-language task for the agent; must be non-empty.
        thread_id: Checkpointer thread id scoping this run's state.

    Raises:
        ValueError: If ``task`` is empty.
        Exception: An error recorded in the state by ``process_step``,
            re-raised here when no final answer was produced.
    """
    if not task:
        raise ValueError("Task cannot be empty")

    logger.info(f"Initializing agent execution for task: {task}")

    initial_state: AgentState = {
        "task": task,
        "current_step": None,
        "error": None,
        "answer_text": None,
    }

    async for chunk in graph.astream(
        initial_state, {"configurable": {"thread_id": thread_id}}
    ):
        yield chunk

        # BUG FIX: astream yields {node_name: state} wrappers. The original
        # code called .get("error") on the wrapper itself, so the error check
        # could never fire (and `state` was unbound for an empty stream).
        # Inspect each node's state update instead.
        for node_state in chunk.values():
            if node_state.get("error") and not node_state.get("answer_text"):
                logger.error(f"Propagating error from stream: {node_state['error']}")
                raise Exception(node_state["error"])
|
|
|
|
|
def _extract_node_state(chunk: Optional[dict]) -> Optional[dict]:
    """Unwrap a graph stream chunk ({node_name: state}) to the node's state.

    Returns None for an empty/None chunk (the original indexed
    list(keys)[0], which raised IndexError on an empty dict).
    """
    if not chunk:
        return None
    return chunk.get(next(iter(chunk)))


def _print_step(step: dict) -> None:
    """Pretty-print one recorded agent step to stdout."""
    print(f"\nStep {step['step_number']}:")
    print(f"Model Output: {step['model_output']}")
    print(f"Observations: {step['observations']}")
    if step.get("tool_calls"):
        print("Tool Calls:")
        for tc in step["tool_calls"]:
            print(f" - {tc['name']}: {tc['arguments']}")
    if step.get("action_output"):
        print(f"Action Output: {step['action_output']}")


async def run_with_streaming(task: str, thread_id: str) -> dict:
    """Run the agent with streaming output and return the results.

    Args:
        task: Natural-language task for the agent.
        thread_id: Checkpointer thread id for this run.

    Returns:
        dict with keys ``steps`` (ordered list of action-step dicts),
        ``final_answer`` (str or None) and ``error`` (str or None).
    """
    last_state = None
    steps = []
    error = None
    final_answer_text = None

    try:
        logger.info(f"Starting execution run for task: {task}")
        async for chunk in stream_execution(task, thread_id):
            last_state = chunk

            # BUG FIX: each chunk is {node_name: state}. The original read
            # "current_step" off the wrapper dict, which is always absent,
            # so steps were never collected or printed while streaming.
            node_state = _extract_node_state(chunk)
            if node_state and (current_step := node_state.get("current_step")):
                # Deduplicate: the same step may appear in several chunks.
                if not steps or steps[-1]["step_number"] != current_step["step_number"]:
                    steps.append(current_step)
                    _print_step(current_step)

        logger.info("Stream finished.")
        if last_state:
            actual_state = _extract_node_state(last_state)
            if actual_state:
                final_answer_text = actual_state.get("answer_text")
                error = actual_state.get("error")
                logger.info(
                    f"Final answer text extracted from last state: {final_answer_text}"
                )
                logger.info(f"Error extracted from last state: {error}")

                # Keep the last step if the loop above missed it.
                last_step_in_state = actual_state.get("current_step")
                if last_step_in_state and (
                    not steps
                    or steps[-1]["step_number"] != last_step_in_state["step_number"]
                ):
                    logger.debug("Adding last step from final state to steps list.")
                    steps.append(last_step_in_state)
            else:
                logger.warning(
                    "Could not find actual state dictionary within last_state."
                )

        return {"steps": steps, "final_answer": final_answer_text, "error": error}

    except Exception as e:
        # logger.exception attaches the traceback, replacing the manual
        # `import traceback` / format_exc() dance.
        logger.exception(f"Exception during run_with_streaming: {str(e)}")

        # Best effort: salvage any answer already present in the last state.
        final_answer_text = None
        error_msg = str(e)
        actual_state = _extract_node_state(last_state)
        if actual_state:
            final_answer_text = actual_state.get("answer_text")

        return {"steps": steps, "final_answer": final_answer_text, "error": error_msg}
|
|
|
|
|
if __name__ == "__main__":
    # asyncio is only needed for this script entry point; uuid is already
    # imported at module level (the original re-imported it here).
    import asyncio

    # Demo task exercising the full streaming pipeline.
    task_to_run = "What is the capital of France?"
    thread_id = str(uuid.uuid4())
    logger.info(
        f"Starting agent run from __main__ for task: '{task_to_run}' with thread_id: {thread_id}"
    )
    result = asyncio.run(run_with_streaming(task_to_run, thread_id))
    logger.info("Agent run finished.")

    print("\n--- Execution Results ---")
    print(f"Number of Steps: {len(result.get('steps', []))}")
    print(f"Final Answer: {result.get('final_answer') or 'Not found'}")
    if err := result.get("error"):
        print(f"Error: {err}")
|
|