"""Define the agent graph and its components."""

import logging
import os
from typing import Dict, List, Optional, TypedDict, Union

import litellm
import yaml
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.errors import GraphInterrupt
from langgraph.graph import END, StateGraph
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel

from configuration import Configuration
from tools import tools

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable LiteLLM debug logging only if LITELLM_DEBUG is "true" (case-insensitive);
# the same switch raises this module's logger to DEBUG.
if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
    litellm.set_verbose = True
    logger.setLevel(logging.DEBUG)
else:
    litellm.set_verbose = False
    logger.setLevel(logging.INFO)

# Configure LiteLLM to drop parameters unsupported by the target provider
litellm.drop_params = True

# Load default prompt templates from prompts/code_agent.yaml next to this file
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
with open(yaml_path, "r", encoding="utf-8") as f:
    prompt_templates = yaml.safe_load(f)

# Initialize the model and agent using configuration
config = Configuration()
model = LiteLLMModel(
    api_base=config.api_base,
    api_key=config.api_key,
    model_id=config.model_id,
)
agent = CodeAgent(
    add_base_tools=True,
    max_steps=1,  # Execute one step at a time
    model=model,
    prompt_templates=prompt_templates,
    tools=tools,
    # NOTE(review): logging.DEBUG is the stdlib level (10), not a smolagents
    # verbosity level; presumably meant as "most verbose" -- confirm.
    verbosity_level=logging.DEBUG,
)


class AgentState(TypedDict):
    """State for the agent graph."""

    # Conversation history; AgentNode appends one AIMessage per run.
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    # Task text passed to the agent on every run.
    question: str
    # Output of the most recent agent run (None before the first run).
    answer: Optional[str]
    # One snapshot dict per StepCallbackNode invocation.
    step_logs: List[Dict]
    # Set to True by StepCallbackNode when the user chooses to quit.
    is_complete: bool
    # Number of agent runs performed so far.
    step_count: int


class AgentNode:
    """Graph node that runs the agent once on the current question."""

    def __init__(self, agent: CodeAgent):
        """Initialize the agent node with an agent."""
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent on ``state["question"]`` and record its output.

        Args:
            state: Current graph state.
            config: Optional runnable config, parsed into a ``Configuration``
                (currently only logged).

        Returns:
            A new state dict with the agent output appended to ``messages``,
            stored in ``answer``, and ``step_count`` incremented. The incoming
            state and its lists are left unmodified.
        """
        logger.info("Current state before processing:")
        logger.info("Messages: %s", state["messages"])
        logger.info("Question: %s", state["question"])
        logger.info("Answer: %s", state["answer"])

        cfg = Configuration.from_runnable_config(config)
        logger.info("Using configuration: %s", cfg)

        logger.info("Starting agent execution")
        # NOTE(review): smolagents' run() resets agent memory by default, so with
        # max_steps=1 each graph iteration may restart the task from scratch
        # rather than continue it -- confirm this is intended.
        result = self.agent.run(state["question"])

        logger.info("Agent execution result type: %s", type(result))
        logger.info("Agent execution result value: %s", result)

        # AIMessage content must be text; coerce other return types so that
        # message construction cannot fail on e.g. numeric final answers.
        answer = result if isinstance(result, str) else str(result)

        # Build fresh containers: state.copy() is shallow, so appending to
        # state["messages"] would mutate the list owned by the incoming state.
        new_state = state.copy()
        new_state["messages"] = [*state["messages"], AIMessage(content=answer)]
        new_state["answer"] = answer
        new_state["step_count"] = state["step_count"] + 1

        logger.info("Updated state after processing:")
        logger.info("Messages: %s", new_state["messages"])
        logger.info("Question: %s", new_state["question"])
        logger.info("Answer: %s", new_state["answer"])

        return new_state


class StepCallbackNode:
    """Graph node that logs the step and asks the user how to proceed."""

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Record a step log, then prompt the user via a LangGraph interrupt.

        Input is handled as follows:
        - ``'q'`` sets ``is_complete``, which routes the graph to END.
        - ``'c'`` continues to the next agent step.
        - ``'i'`` logs the current step and prompts again.
        - Any other value logs a warning and prompts again.

        Args:
            state: Current graph state.
            config: Unused; accepted for the LangGraph node signature.

        Returns:
            A new state dict with one step log appended (and ``is_complete``
            set on quit). The incoming state's lists are not mutated.

        Raises:
            GraphInterrupt: Propagated from ``interrupt()`` so the graph pauses.
        """
        step_log = {
            "step": state["step_count"],
            "messages": [msg.content for msg in state["messages"]],
            "question": state["question"],
            "answer": state["answer"],
        }
        # New list rather than append(): avoids mutating the incoming state.
        # Logged once per node run, even if the user re-prompts several times.
        new_state = state.copy()
        new_state["step_logs"] = [*state["step_logs"], step_log]

        try:
            # Loop instead of recursing; each interrupt() call pauses the graph
            # and returns the resume value on the corresponding resume.
            while True:
                user_input = interrupt(
                    "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
                )
                choice = str(user_input).strip().lower()
                if choice == "q":
                    new_state["is_complete"] = True
                    return new_state
                if choice == "c":
                    return new_state
                if choice == "i":
                    logger.info("Current step: %s", state["step_count"])
                    logger.info("Question: %s", state["question"])
                    logger.info("Current answer: %s", state["answer"])
                else:
                    logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
        except GraphInterrupt:
            # interrupt() pauses the graph by raising; swallowing this would
            # prevent the graph from ever waiting for user input.
            raise
        except Exception as e:
            # Best-effort fallback: continue without user input on other errors.
            logger.warning("Error during interrupt: %s", e)
            return new_state


def _route_after_callback(state: AgentState) -> bool:
    """Return True when the user chose to quit (key into the edge path map)."""
    return bool(state.get("is_complete", False))


def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build and compile the agent graph.

    The graph runs ``agent`` -> ``callback`` repeatedly until the callback node
    sets ``is_complete``, at which point it routes to END.

    Args:
        agent: Node that runs the agent for one step.

    Returns:
        The result of ``StateGraph.compile()`` (a compiled graph, despite the
        ``StateGraph`` annotation kept for compatibility).
    """
    workflow = StateGraph(AgentState)

    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode())

    workflow.add_edge("agent", "callback")
    # The router's return value is looked up in path_map, so it must return
    # the boolean keys used here -- not END / "agent" directly.
    workflow.add_conditional_edges(
        "callback",
        _route_after_callback,
        {True: END, False: "agent"},
    )

    workflow.set_entry_point("agent")

    return workflow.compile()


# Initialize the agent graph
agent_graph = build_agent_graph(AgentNode(agent))