|
"""Define the agent graph and its components.""" |
|
|
|
import logging
import os
from typing import Dict, List, Optional, TypedDict, Union

import yaml
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.errors import GraphInterrupt
from langgraph.graph import END, StateGraph
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel

from configuration import Configuration
from tools import tools
|
|
|
|
|
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

import litellm

# LITELLM_DEBUG=true (case-insensitive) turns on LiteLLM's verbose output and
# also lowers this module's logger to DEBUG; any other value keeps both quiet.
_litellm_debug = os.getenv("LITELLM_DEBUG", "false").lower() == "true"
litellm.set_verbose = _litellm_debug
logger.setLevel(logging.DEBUG if _litellm_debug else logging.INFO)

# Per the LiteLLM docs, drop request parameters the target provider does not
# support instead of failing the call.
litellm.drop_params = True
|
|
|
|
|
# The CodeAgent prompt templates live in prompts/code_agent.yaml next to this
# module; resolve the path from __file__ so it does not depend on the CWD.
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

# Decode explicitly as UTF-8. Without `encoding`, open() uses the platform
# locale encoding (e.g. cp1252 on Windows), which can mis-decode or reject
# non-ASCII prompt text.
with open(yaml_path, "r", encoding="utf-8") as f:
    prompt_templates = yaml.safe_load(f)
|
|
|
|
|
# The model and agent are built once, at import time, from a
# default-constructed Configuration.
config = Configuration()
model = LiteLLMModel(
    api_base=config.api_base,
    api_key=config.api_key,
    model_id=config.model_id,
)

agent = CodeAgent(
    add_base_tools=True,
    # Each agent.run() call gets a single step; the LangGraph loop
    # (agent -> callback -> agent) is what repeats the work.
    max_steps=1,
    model=model,
    prompt_templates=prompt_templates,
    tools=tools,
    # smolagents has its own verbosity scale (LogLevel: OFF=-1, ERROR=0,
    # INFO=1, DEBUG=2), not the stdlib logging levels. The previous value,
    # logging.DEBUG (== 10), was on the wrong scale. 2 is LogLevel.DEBUG,
    # the intended maximum verbosity.
    verbosity_level=2,
)
|
|
|
|
|
class AgentState(TypedDict):
    """State for the agent graph.

    Shared by the ``agent`` and ``callback`` nodes; each node returns the
    whole state dict rather than a partial update.
    """

    # Conversation history; the agent node appends one AIMessage per run.
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    # Task text handed to the CodeAgent's run() on every agent step.
    question: str
    # Latest agent result. Presumably None before the first run; the initial
    # value comes from the caller's input, so confirm there.
    answer: Optional[str]
    # One dict per callback step: step number, message contents, question,
    # and answer at that point.
    step_logs: List[Dict]
    # Set to True when the user replies 'q'; the graph then routes to END.
    is_complete: bool
    # Number of agent runs so far; incremented by the agent node.
    step_count: int
|
|
|
|
|
class AgentNode:
    """Node that runs the agent."""

    def __init__(self, agent: CodeAgent):
        """Initialize the agent node with an agent.

        Args:
            agent: The CodeAgent invoked on every call of this node.
        """
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent once on ``state["question"]`` and record the result.

        Args:
            state: Current graph state. It is not modified.
            config: Optional runnable config. It is resolved into a
                Configuration only for logging; the module-level model and
                agent are not rebuilt from it.

        Returns:
            A new state dict in which the stringified agent result is
            appended to ``messages`` as an AIMessage and stored in
            ``answer``, and ``step_count`` is incremented by one.
        """
        logger.info("Current state before processing:")
        logger.info("Messages: %s", state["messages"])
        logger.info("Question: %s", state["question"])
        logger.info("Answer: %s", state["answer"])

        cfg = Configuration.from_runnable_config(config)
        logger.info("Using configuration: %s", cfg)

        logger.info("Starting agent execution")
        result = self.agent.run(state["question"])
        logger.info("Agent execution result type: %s", type(result))
        logger.info("Agent execution result value: %s", result)

        # agent.run() may return a non-string final answer (number, dict,
        # ...). AIMessage content and the Optional[str] ``answer`` field both
        # need text, so normalize here.
        answer = None if result is None else str(result)

        # Build a fresh dict instead of state.copy(). A shallow copy shares
        # the ``messages`` list, so appending to it mutated the caller's state.
        new_state = {
            **state,
            "messages": [*state["messages"], AIMessage(content=answer or "")],
            "answer": answer,
            "step_count": state["step_count"] + 1,
        }

        logger.info("Updated state after processing:")
        logger.info("Messages: %s", new_state["messages"])
        logger.info("Question: %s", new_state["question"])
        logger.info("Answer: %s", new_state["answer"])

        return new_state
|
|
|
|
|
class StepCallbackNode:
    """Node that handles step callbacks and user interaction."""

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Record a step log, then ask the user how to proceed.

        The prompt is sent with LangGraph's ``interrupt``. The resume value
        is compared case-insensitively, ignoring surrounding whitespace:

        - ``q``: set ``is_complete`` so the graph routes to END.
        - ``c``: continue with the state as-is, apart from the new log entry.
        - ``i``: log the current step, question, and answer, then ask again.
        - anything else: log a warning and ask again.

        Args:
            state: Current graph state. It is not modified.
            config: Optional runnable config. It is unused and kept for the
                node-callable signature.

        Returns:
            A new state dict with exactly one entry appended to
            ``step_logs`` and, if the user quit, ``is_complete`` set to True.

        Raises:
            GraphInterrupt: Propagated from ``interrupt`` so LangGraph can
                pause the run and resume it later with the user's reply.
        """
        step_log = {
            "step": state["step_count"],
            "messages": [msg.content for msg in state["messages"]],
            "question": state["question"],
            "answer": state["answer"],
        }
        # Log once per callback step and leave the incoming state untouched.
        # The former recursive retry appended a duplicate entry on every 'i'
        # or invalid reply.
        new_state = {**state, "step_logs": [*state["step_logs"], step_log]}

        while True:
            try:
                user_input = interrupt(
                    "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
                )
            except GraphInterrupt:
                # interrupt() pauses the graph by raising GraphInterrupt. It
                # must reach LangGraph; the old broad `except Exception`
                # swallowed it, so the graph never actually paused.
                raise
            except Exception as e:
                logger.warning("Error during interrupt: %s", e)
                return new_state

            # The resume value is whatever the caller supplied and may not be
            # a string; calling .lower() on it directly could raise.
            choice = str(user_input).strip().lower()
            if choice == "q":
                new_state["is_complete"] = True
                return new_state
            if choice == "c":
                return new_state
            if choice == "i":
                logger.info("Current step: %s", state["step_count"])
                logger.info("Question: %s", state["question"])
                logger.info("Current answer: %s", state["answer"])
            else:
                logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
|
|
|
|
|
def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build and compile the agent graph.

    Execution starts at ``agent``, always proceeds to ``callback``, and from
    there loops back to ``agent`` until the state's ``is_complete`` flag is
    set, at which point it ends.

    Args:
        agent: Node callable that runs the CodeAgent.

    Returns:
        The result of ``workflow.compile()``. This is the compiled graph,
        not the uncompiled ``StateGraph`` builder named in the annotation.
    """
    workflow = StateGraph(AgentState)

    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode())

    workflow.add_edge("agent", "callback")

    def route_after_callback(state: AgentState) -> str:
        """Return END once the user has quit, otherwise loop back to 'agent'."""
        # .get: the key may be absent if the initial input omitted it.
        return END if state.get("is_complete", False) else "agent"

    # The path_map must be keyed by the router's return values. The previous
    # map used True/False keys while the router returned END / "agent", so
    # no routed value matched any key.
    workflow.add_conditional_edges(
        "callback",
        route_after_callback,
        {END: END, "agent": "agent"},
    )

    workflow.set_entry_point("agent")

    return workflow.compile()
|
|
|
|
|
|
|
# Module-level compiled graph: the shared CodeAgent wrapped as the "agent" node.
agent_graph = build_agent_graph(agent=AgentNode(agent=agent))
|
|