gperdrizet commited on
Commit
f97da2b
·
verified ·
1 Parent(s): a28b1b4

Improved re-prompting after LLM makes tool call.

Browse files
client/anthropic_bridge.py CHANGED
@@ -62,11 +62,13 @@ class LLMBridge(abc.ABC):
62
 
63
 
64
  @abc.abstractmethod
65
- async def submit_query(self, query: str, formatted_tools: Any) -> Dict[str, Any]:
66
  '''Submit a query to the LLM with the formatted tools.
67
 
68
  Args:
69
- query: User query string
 
 
70
  formatted_tools: Tools in the LLM-specific format
71
 
72
  Returns:
@@ -101,7 +103,7 @@ class LLMBridge(abc.ABC):
101
  return await self.mcp_client.invoke_tool(tool_name, kwargs)
102
 
103
 
104
- async def process_query(self, query: str) -> Dict[str, Any]:
105
  '''Process a user query through the LLM and execute any tool calls.
106
 
107
  This method handles the full flow:
@@ -125,7 +127,7 @@ class LLMBridge(abc.ABC):
125
  formatted_tools = await self.format_tools(self.tools)
126
 
127
  # 3. Submit query to LLM
128
- llm_response = await self.submit_query(query, formatted_tools)
129
 
130
  # 4. Parse tool calls from LLM response
131
  tool_call = await self.parse_tool_call(llm_response)
@@ -176,7 +178,8 @@ class AnthropicBridge(LLMBridge):
176
 
177
  async def submit_query(
178
  self,
179
- query: str,
 
180
  formatted_tools: List[Dict[str, Any]]
181
  ) -> Dict[str, Any]:
182
  '''Submit a query to Anthropic with the formatted tools.
@@ -191,10 +194,7 @@ class AnthropicBridge(LLMBridge):
191
  response = self.llm_client.messages.create(
192
  model=self.model,
193
  max_tokens=4096,
194
- system='You are a helpful tool-using assistant.',
195
- # messages=[
196
- # {'role': 'user', 'content': query}
197
- # ],
198
  messages=query,
199
  tools=formatted_tools
200
  )
 
62
 
63
 
64
  @abc.abstractmethod
65
+ async def submit_query(self, system_prompt: str, query: List[Dict], formatted_tools: Any) -> Dict[str, Any]:
66
  '''Submit a query to the LLM with the formatted tools.
67
 
68
  Args:
69
+ query: User query string # messages=[
70
+ # {'role': 'user', 'content': query}
71
+ # ],
72
  formatted_tools: Tools in the LLM-specific format
73
 
74
  Returns:
 
103
  return await self.mcp_client.invoke_tool(tool_name, kwargs)
104
 
105
 
106
+ async def process_query(self, system_prompt: str, query: List[Dict]) -> Dict[str, Any]:
107
  '''Process a user query through the LLM and execute any tool calls.
108
 
109
  This method handles the full flow:
 
127
  formatted_tools = await self.format_tools(self.tools)
128
 
129
  # 3. Submit query to LLM
130
+ llm_response = await self.submit_query(system_prompt, query, formatted_tools)
131
 
132
  # 4. Parse tool calls from LLM response
133
  tool_call = await self.parse_tool_call(llm_response)
 
178
 
179
  async def submit_query(
180
  self,
181
+ system_prompt: str,
182
+ query: List[Dict],
183
  formatted_tools: List[Dict[str, Any]]
184
  ) -> Dict[str, Any]:
185
  '''Submit a query to Anthropic with the formatted tools.
 
194
  response = self.llm_client.messages.create(
195
  model=self.model,
196
  max_tokens=4096,
197
+ system=system_prompt,
 
 
 
198
  messages=query,
199
  tools=formatted_tools
200
  )
client/interface.py CHANGED
@@ -2,44 +2,92 @@
2
 
3
  import json
4
  import logging
 
5
  from gradio.components.chatbot import ChatMessage
 
 
6
  from client.anthropic_bridge import AnthropicBridge
7
 
8
- async def agent_input(bridge: AnthropicBridge, chat_history: list) -> list:
 
 
 
 
9
  '''Handles model interactions.'''
10
 
11
  function_logger = logging.getLogger(__name__ + '.agent_input')
12
 
13
  input_messages = format_chat_history(chat_history)
14
- result = await bridge.process_query(input_messages)
15
- function_logger.info(result)
16
 
17
- try:
18
- chat_history.append({
19
- "role": "assistant",
20
- "content": result['llm_response'].content[0].text
21
- })
22
 
23
- except AttributeError:
24
- function_logger.info('Model called the tool, but did not talk about it')
25
 
26
- if result['tool_result']:
27
- articles = json.loads(result['tool_result'].content)['text']
28
- function_logger.info(articles)
29
- tmp_chat_history = chat_history.copy()
30
- tmp_chat_history.append({
31
- "role": "assistant",
32
- "content": ('Here are the three most recent entries from the RSS ' +
33
- f'feed in JSON format. Tell the user what you have found: {json.dumps(articles)}')
34
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- tmp_input_messages = format_chat_history(tmp_chat_history)
37
- function_logger.info(tmp_input_messages)
38
- result = await bridge.process_query(tmp_input_messages)
39
 
40
  chat_history.append({
41
  "role": "assistant",
42
- "content": result['llm_response'].content[0].text
43
  })
44
 
45
  return chat_history
 
2
 
3
  import json
4
  import logging
5
+ from anthropic.types import text_block
6
  from gradio.components.chatbot import ChatMessage
7
+
8
+ from client import prompts
9
  from client.anthropic_bridge import AnthropicBridge
10
 
11
+ async def agent_input(
12
+ bridge: AnthropicBridge,
13
+ chat_history: list
14
+ ) -> list:
15
+
16
  '''Handles model interactions.'''
17
 
18
  function_logger = logging.getLogger(__name__ + '.agent_input')
19
 
20
  input_messages = format_chat_history(chat_history)
21
+ result = await bridge.process_query(prompts.DEFAULT_SYSTEM_PROMPT, input_messages)
 
22
 
23
+ if result['tool_result']:
24
+ tool_call = result['tool_call']
25
+ tool_name = tool_call['name']
 
 
26
 
27
+ if tool_name == 'rss_mcp_server_get_feed':
 
28
 
29
+ tool_parameters = tool_call['parameters']
30
+ website = tool_parameters['website']
31
+ user_query = input_messages[-1]['content']
32
+ response_content = result['llm_response'].content[0]
33
+
34
+ if isinstance(response_content, text_block.TextBlock):
35
+ intermediate_reply = response_content.text
36
+ else:
37
+ intermediate_reply = f'I Will check the {website} RSS feed for you'
38
+
39
+ function_logger.info('User query: %s', user_query)
40
+ function_logger.info('Model intermediate reply: %s', intermediate_reply)
41
+ function_logger.info('LLM called %s on %s', tool_name, website)
42
+
43
+ articles = json.loads(result['tool_result'].content)['text']
44
+
45
+ prompt = prompts.GET_FEED_PROMPT.substitute(
46
+ website=website,
47
+ user_query=user_query,
48
+ intermediate_reply=intermediate_reply,
49
+ articles=articles
50
+ )
51
+
52
+ input_message =[{
53
+ 'role': 'user',
54
+ 'content': prompt
55
+ }]
56
+
57
+ function_logger.info('Re-prompting input %s', input_message)
58
+ result = await bridge.process_query(prompts.GET_FEED_SYSTEM_PROMPT, input_message)
59
+
60
+ try:
61
+
62
+ final_reply = result['llm_response'].content[0].text
63
+
64
+ except (IndexError, AttributeError):
65
+ final_reply = 'No final reply from model'
66
+
67
+ function_logger.info('LLM final reply: %s', final_reply)
68
+
69
+ chat_history.append({
70
+ "role": "assistant",
71
+ "content": intermediate_reply
72
+ })
73
+
74
+ chat_history.append({
75
+ "role": "assistant",
76
+ "content": final_reply
77
+ })
78
+
79
+ else:
80
+ try:
81
+ reply = result['llm_response'].content[0].text
82
+
83
+ except AttributeError:
84
+ reply = 'Bad reply - could not parse'
85
 
86
+ function_logger.info('Direct, no-tool reply: %s', reply)
 
 
87
 
88
  chat_history.append({
89
  "role": "assistant",
90
+ "content": reply
91
  })
92
 
93
  return chat_history
client/mcp_client.py CHANGED
@@ -83,8 +83,6 @@ class MCPClientWrapper:
83
  self.endpoint = endpoint
84
  self.timeout = timeout
85
  self.max_retries = max_retries
86
- # self.tools = None
87
- # self.anthropic = Anthropic()
88
 
89
 
90
  async def _execute_with_retry(self, operation_name: str, operation_func):
 
83
  self.endpoint = endpoint
84
  self.timeout = timeout
85
  self.max_retries = max_retries
 
 
86
 
87
 
88
  async def _execute_with_retry(self, operation_name: str, operation_func):
client/prompts.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''Collection of prompts for Claude API calls in different
2
+ conversational contexts.'''
3
+
4
+ from string import Template
5
+
6
+ DEFAULT_SYSTEM_PROMPT = 'You are a helpful tool-using assistant.'
7
+
8
+ GET_FEED_SYSTEM_PROMPT = '''
9
+ You are a helpful assistant. Your job is to facilitate interactions between
10
+ Human users and LLM agents.
11
+ '''
12
+
13
+ GET_FEED_PROMPT = Template(
14
+ '''
15
+ Below is an exchange between a user and an agent. The user has asked
16
+ the agent to get new content from the $website RSS feed. In order to
17
+ complete the request, the agent has called a function which returned
18
+ the RSS feed content from $website in JSON format. Your job is to
19
+ complete the exchange by using the returned JSON RSS feed data to write
20
+ a human readable reply to the user.
21
+
22
+ user: $user_query
23
+
24
+ agent: $intermediate_reply
25
+
26
+ function call: get_feed_content($website)
27
+
28
+ function return: $articles
29
+
30
+ assistant:
31
+ '''
32
+ )
rss_client.py CHANGED
@@ -29,17 +29,28 @@ logging.basicConfig(
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
32
  RSS_CLIENT = MCPClientWrapper(
33
  'https://agents-mcp-hackathon-rss-mcp-server.hf.space/gradio_api/mcp/sse'
34
  )
35
 
 
36
  BRIDGE = AnthropicBridge(
37
  RSS_CLIENT,
38
  api_key=os.environ['ANTHROPIC_API_KEY']
39
  )
40
 
41
  async def send_message(message: str, chat_history: list) -> str:
42
- '''Submits user message to agent'''
 
 
 
 
 
 
 
 
 
43
 
44
  function_logger = logging.getLogger(__name__ + '.submit_input')
45
  function_logger.info('Submitting user message: %s', message)
@@ -65,20 +76,26 @@ with gr.Blocks(title='MCP RSS client') as demo:
65
 
66
  chatbot = gr.Chatbot(
67
  value=[],
68
- height=500,
69
  type='messages',
70
  show_copy_button=True
71
  )
72
 
73
  msg = gr.Textbox(
 
74
  label='Ask about content or articles on a site or platform',
75
  placeholder='Is there anything new on Hacker News?',
76
  scale=4
77
  )
78
 
79
  connect_btn.click(RSS_CLIENT.list_tools, outputs=status) # pylint: disable=no-member
80
- msg.submit(send_message, [msg, chatbot], [msg, chatbot]) # pylint: disable=no-member
 
 
 
 
 
81
 
82
  if __name__ == '__main__':
83
 
84
- demo.launch(debug=True)
 
29
 
30
  logger = logging.getLogger(__name__)
31
 
32
+ # Handle MCP server connection and interactions
33
  RSS_CLIENT = MCPClientWrapper(
34
  'https://agents-mcp-hackathon-rss-mcp-server.hf.space/gradio_api/mcp/sse'
35
  )
36
 
37
+ # Handles Anthropic API I/O
38
  BRIDGE = AnthropicBridge(
39
  RSS_CLIENT,
40
  api_key=os.environ['ANTHROPIC_API_KEY']
41
  )
42
 
43
  async def send_message(message: str, chat_history: list) -> str:
44
+ '''Submits user message to agent.
45
+
46
+ Args:
47
+ message: the new message from the user as a string
48
+ chat_history: list containing conversation history where each element is
49
+ a dictionary with keys 'role' and 'content'
50
+
51
+ Returns
52
+ New chat history with model's response to user added.
53
+ '''
54
 
55
  function_logger = logging.getLogger(__name__ + '.submit_input')
56
  function_logger.info('Submitting user message: %s', message)
 
76
 
77
  chatbot = gr.Chatbot(
78
  value=[],
79
+ height=200,
80
  type='messages',
81
  show_copy_button=True
82
  )
83
 
84
  msg = gr.Textbox(
85
+ 'Are there any new posts on Hacker News?',
86
  label='Ask about content or articles on a site or platform',
87
  placeholder='Is there anything new on Hacker News?',
88
  scale=4
89
  )
90
 
91
  connect_btn.click(RSS_CLIENT.list_tools, outputs=status) # pylint: disable=no-member
92
+
93
+ msg.submit( # pylint: disable=no-member
94
+ send_message,
95
+ [msg, chatbot],
96
+ [msg, chatbot]
97
+ )
98
 
99
  if __name__ == '__main__':
100
 
101
+ demo.launch(server_name="0.0.0.0", server_port=7860)