helloparthshah Kunal Pai harshil-21 committed on
Commit
60b4d0f
·
1 Parent(s): d648fe6

Implemented thinking and fixed a bug with agent creator

Browse files

Co-authored-by: Kunal Pai <kunpai@users.noreply.github.com>
Co-authored-by: Harshil Patel <hpppatel@ucdavis.edu>

Files changed (2) hide show
  1. src/agent_manager.py +5 -3
  2. src/manager.py +42 -29
src/agent_manager.py CHANGED
@@ -146,12 +146,13 @@ class AgentManager():
146
  if not agent_class:
147
  raise ValueError(f"Unsupported base model {base_model}")
148
 
 
 
149
  self.validate_budget(create_cost)
150
 
151
  self.budget_manager.add_to_expense(create_cost)
152
  # create agent
153
- return agent_class(agent_name, base_model, system_prompt, create_cost,invoke_cost)
154
-
155
 
156
  def get_agent(self, agent_name: str) -> Agent:
157
  """Get existing agent by name"""
@@ -172,7 +173,8 @@ class AgentManager():
172
  simplified_agents[name] = {
173
  "description": data.get("description", ""),
174
  "create_cost": data.get("create_cost", 0),
175
- "invoke_cost": data.get("invoke_cost", 0)
 
176
  }
177
  return simplified_agents
178
  else:
 
146
  if not agent_class:
147
  raise ValueError(f"Unsupported base model {base_model}")
148
 
149
+ created_agent = agent_class(agent_name, base_model, system_prompt, create_cost,invoke_cost)
150
+
151
  self.validate_budget(create_cost)
152
 
153
  self.budget_manager.add_to_expense(create_cost)
154
  # create agent
155
+ return created_agent
 
156
 
157
  def get_agent(self, agent_name: str) -> Agent:
158
  """Get existing agent by name"""
 
173
  simplified_agents[name] = {
174
  "description": data.get("description", ""),
175
  "create_cost": data.get("create_cost", 0),
176
+ "invoke_cost": data.get("invoke_cost", 0),
177
+ "base_model": data.get("base_model", ""),
178
  }
179
  return simplified_agents
180
  else:
src/manager.py CHANGED
@@ -38,26 +38,9 @@ class GeminiManager:
38
  with open(system_prompt_file, 'r', encoding="utf8") as f:
39
  self.system_prompt = f.read()
40
  self.messages = []
41
-
42
  def generate_response(self, messages):
43
  tools = self.toolsLoader.getTools()
44
- function = types.FunctionDeclaration(
45
- name="DigestConversation",
46
- description="Digest the conversation and store the summary provided.",
47
- parameters=types.Schema(
48
- type = "object",
49
- properties={
50
- # string that summarizes the conversation
51
- "summary": types.Schema(
52
- type="string",
53
- description="A summary of the conversation including all the important points.",
54
- ),
55
- },
56
- required=["summary"],
57
- ),
58
- )
59
- toolType = types.Tool(function_declarations=[function])
60
- tools.append(toolType)
61
  return self.client.models.generate_content(
62
  model=self.model_name,
63
  contents=messages,
@@ -70,17 +53,23 @@ class GeminiManager:
70
 
71
  def handle_tool_calls(self, response):
72
  parts = []
 
73
  for function_call in response.function_calls:
 
 
74
  toolResponse = None
75
  logger.info(
76
  f"Function Name: {function_call.name}, Arguments: {function_call.args}")
77
- if function_call.name == "DigestConversation":
78
- logger.info("Digesting conversation...")
79
- summary = function_call.args["summary"]
80
- return {
81
- "role": "summary",
82
- "content": f"{summary}",
 
 
83
  }
 
84
  try:
85
  toolResponse = self.toolsLoader.runTool(
86
  function_call.name, function_call.args)
@@ -88,10 +77,20 @@ class GeminiManager:
88
  logger.warning(f"Error running tool: {e}")
89
  toolResponse = {
90
  "status": "error",
91
- "message": f"Tool {function_call.name} failed to run.",
92
  "output": str(e),
93
  }
94
  logger.debug(f"Tool Response: {toolResponse}")
 
 
 
 
 
 
 
 
 
 
95
  tool_content = types.Part.from_function_response(
96
  name=function_call.name,
97
  response={"result": toolResponse})
@@ -99,6 +98,16 @@ class GeminiManager:
99
  self.toolsLoader.load_tools()
100
  except Exception as e:
101
  logger.info(f"Error loading tools: {e}. Deleting the tool.")
 
 
 
 
 
 
 
 
 
 
102
  # delete the created tool
103
  self.toolsLoader.delete_tool(
104
  toolResponse['output']['tool_name'], toolResponse['output']['tool_file_path'])
@@ -106,7 +115,8 @@ class GeminiManager:
106
  name=function_call.name,
107
  response={"result": f"{function_call.name} with {function_call.args} doesn't follow the required format, please read the other tool implementations for reference." + str(e)})
108
  parts.append(tool_content)
109
- return {
 
110
  "role": "tool",
111
  "content": repr(types.Content(
112
  role='model' if self.model_name == "gemini-2.5-pro-exp-03-25" else 'tool',
@@ -141,7 +151,7 @@ class GeminiManager:
141
  parts=parts
142
  ))
143
  return formatted_history
144
-
145
  def run(self, messages):
146
  chat_history = self.format_chat_history(messages)
147
  logger.debug(f"Chat history: {chat_history}")
@@ -183,8 +193,11 @@ class GeminiManager:
183
 
184
  # Invoke the function calls if any and attach the response to the messages
185
  if response.function_calls:
186
- calls = self.handle_tool_calls(response)
187
- messages.append(calls)
 
 
 
188
  yield from self.run(messages)
189
  return
190
  yield messages
 
38
  with open(system_prompt_file, 'r', encoding="utf8") as f:
39
  self.system_prompt = f.read()
40
  self.messages = []
41
+
42
  def generate_response(self, messages):
43
  tools = self.toolsLoader.getTools()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  return self.client.models.generate_content(
45
  model=self.model_name,
46
  contents=messages,
 
53
 
54
  def handle_tool_calls(self, response):
55
  parts = []
56
+ i = 0
57
  for function_call in response.function_calls:
58
+ title = ""
59
+ thinking = ""
60
  toolResponse = None
61
  logger.info(
62
  f"Function Name: {function_call.name}, Arguments: {function_call.args}")
63
+ title = f"Invoking `{function_call.name}` with `{function_call.args}`\n"
64
+ yield {
65
+ "role": "assistant",
66
+ "content": thinking,
67
+ "metadata": {
68
+ "title": title,
69
+ "id": i,
70
+ "status": "pending",
71
  }
72
+ }
73
  try:
74
  toolResponse = self.toolsLoader.runTool(
75
  function_call.name, function_call.args)
 
77
  logger.warning(f"Error running tool: {e}")
78
  toolResponse = {
79
  "status": "error",
80
+ "message": f"Tool `{function_call.name}` failed to run.",
81
  "output": str(e),
82
  }
83
  logger.debug(f"Tool Response: {toolResponse}")
84
+ thinking += f"Tool responded with ```\n{toolResponse}\n```\n"
85
+ yield {
86
+ "role": "assistant",
87
+ "content": thinking,
88
+ "metadata": {
89
+ "title": title,
90
+ "id": i,
91
+ "status": "done",
92
+ }
93
+ }
94
  tool_content = types.Part.from_function_response(
95
  name=function_call.name,
96
  response={"result": toolResponse})
 
98
  self.toolsLoader.load_tools()
99
  except Exception as e:
100
  logger.info(f"Error loading tools: {e}. Deleting the tool.")
101
+ thinking += f"Error loading tools: {e}. Deleting the tool.\n"
102
+ yield {
103
+ "role": "assistant",
104
+ "content": thinking,
105
+ "metadata": {
106
+ "title": title,
107
+ "id": i,
108
+ "status": "done",
109
+ }
110
+ }
111
  # delete the created tool
112
  self.toolsLoader.delete_tool(
113
  toolResponse['output']['tool_name'], toolResponse['output']['tool_file_path'])
 
115
  name=function_call.name,
116
  response={"result": f"{function_call.name} with {function_call.args} doesn't follow the required format, please read the other tool implementations for reference." + str(e)})
117
  parts.append(tool_content)
118
+ i += 1
119
+ yield {
120
  "role": "tool",
121
  "content": repr(types.Content(
122
  role='model' if self.model_name == "gemini-2.5-pro-exp-03-25" else 'tool',
 
151
  parts=parts
152
  ))
153
  return formatted_history
154
+
155
  def run(self, messages):
156
  chat_history = self.format_chat_history(messages)
157
  logger.debug(f"Chat history: {chat_history}")
 
193
 
194
  # Invoke the function calls if any and attach the response to the messages
195
  if response.function_calls:
196
+ for call in self.handle_tool_calls(response):
197
+ yield messages + [call]
198
+ if (call.get("role") == "tool"
199
+ or (call.get("role") == "assistant" and call.get("metadata", {}).get("status") == "done")):
200
+ messages.append(call)
201
  yield from self.run(messages)
202
  return
203
  yield messages