|
import logging |
|
from typing import Callable, List, Optional, TypedDict |
|
from langgraph.graph import StateGraph, END |
|
from smolagents import CodeAgent, ToolCallingAgent, LiteLLMModel |
|
from tools import tools |
|
import yaml |
|
import os |
|
import litellm |
|
|
|
|
|
# Route INFO-and-above records from every logger to the root handler that
# basicConfig installs, then grab this module's own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global LiteLLM switch: drop request parameters the target provider does not
# accept (per the LiteLLM docs) rather than failing the call.
litellm.drop_params = True
|
|
|
|
|
class AgentState(TypedDict): |
|
messages: list |
|
question: str |
|
answer: str | None |
|
|
|
class AgentNode:
    """LangGraph node that answers ``state["question"]`` with a smolagents CodeAgent.

    Construction loads prompt templates from ``prompts/code_agent.yaml``
    (located next to this file), builds a LiteLLM-backed model and wraps it
    in a ``CodeAgent``. Calling the node runs the agent on the question and
    stores the result in ``state["answer"]``.
    """

    def __init__(
        self,
        *,
        model_id: str = "ollama/codellama",
        api_base: str = "http://localhost:11434",
        max_steps: int = 1,
    ):
        """Load the prompt templates and build the model and agent.

        Args:
            model_id: LiteLLM model identifier.
            api_base: Base URL of the model server (default: local Ollama port).
            max_steps: Maximum number of agent steps per ``run`` call.

        Raises:
            OSError: If the prompt YAML file cannot be opened.
            ValueError: If the YAML file does not contain a mapping.
            KeyError: If the mapping has no ``system_prompt`` entry.
        """
        prompt_templates = self._load_prompt_templates()

        logger.info("Default system prompt:")
        logger.info("-" * 80)
        logger.info(prompt_templates["system_prompt"])
        logger.info("-" * 80)

        self.model = LiteLLMModel(
            api_base=api_base,
            api_key=None,
            model_id=model_id,
        )

        step_callbacks: Optional[List[Callable]] = [self._log_step]

        self.agent = CodeAgent(
            add_base_tools=True,
            max_steps=max_steps,
            model=self.model,
            prompt_templates=prompt_templates,
            step_callbacks=step_callbacks,
            tools=tools,
            # smolagents uses its own LogLevel scale (OFF=-1, ERROR=0, INFO=1,
            # DEBUG=2), not the stdlib logging levels; logging.DEBUG (10) only
            # worked by accident. 2 == smolagents LogLevel.DEBUG.
            verbosity_level=2,
        )

    @staticmethod
    def _load_prompt_templates() -> dict:
        """Read and parse ``prompts/code_agent.yaml`` next to this module."""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        yaml_path = os.path.join(current_dir, "prompts", "code_agent.yaml")

        # Explicit encoding so the prompt file reads the same on every platform.
        with open(yaml_path, encoding="utf-8") as f:
            prompt_templates = yaml.safe_load(f)

        # safe_load returns None for an empty file; fail with a clear message
        # instead of a TypeError when "system_prompt" is looked up.
        if not isinstance(prompt_templates, dict):
            raise ValueError(f"Expected a mapping of prompt templates in {yaml_path}")
        return prompt_templates

    @staticmethod
    def _log_step(step) -> None:
        """Step callback: log a one-line summary after each agent step.

        The previous callback read ``step.action``, which smolagents memory
        steps do not appear to define; an exception raised inside a step
        callback can abort ``agent.run``. Attributes are read defensively.
        """
        step_number = getattr(step, "step_number", None)
        # NOTE(review): ``action_output`` is the ActionStep output field in the
        # smolagents versions checked; confirm for the installed version.
        action = getattr(step, "action_output", None)
        logger.info("Step %s completed: %s", step_number, action)

    @staticmethod
    def _log_state(header: str, state: AgentState) -> None:
        """Log the three AgentState fields; keys absent from *state* show as None."""
        logger.info(header)
        # .get: LangGraph may hand a node a state dict without keys that have
        # never been written (e.g. "answer" when the input omits it).
        logger.info("Messages: %s", state.get("messages"))
        logger.info("Question: %s", state.get("question"))
        logger.info("Answer: %s", state.get("answer"))

    def __call__(self, state: AgentState) -> AgentState:
        """Run the agent on ``state["question"]`` and store its result in ``state["answer"]``.

        Args:
            state: Current graph state; must contain ``question``.

        Returns:
            The same state object, updated in place. Exceptions raised while
            processing are logged with traceback and reported as
            ``state["answer"] = "Error: <message>"`` instead of propagating.
        """
        try:
            self._log_state("Current state before processing:", state)

            logger.info("Calling agent.run()...")
            result = self.agent.run(state["question"])

            logger.info("Agent run completed:")
            logger.info("Result type: %s", type(result))
            logger.info("Result value: %s", result)

            state["answer"] = result

            self._log_state("Updated state after processing:", state)
            return state

        except Exception as e:
            # Graph-node boundary: convert any failure into an error answer so
            # the graph run finishes instead of raising to the caller.
            logger.error("Error in agent node: %s", e, exc_info=True)
            state["answer"] = f"Error: {str(e)}"
            return state
|
|
|
def build_agent_graph():
    """Assemble and compile the single-node agent workflow.

    The graph starts at the "agent" node (a fresh ``AgentNode``) and
    terminates immediately after it.

    Returns:
        The object produced by ``StateGraph.compile()``.
    """
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", AgentNode())

    # Linear flow: entry -> "agent" -> END.
    workflow.set_entry_point("agent")
    workflow.add_edge("agent", END)

    return workflow.compile()
|
|
|
|
|
# Compiled graph built once at import time. Importing this module therefore
# reads the prompt YAML and constructs the model and agent.
agent_graph = build_agent_graph()