Improved re-prompting after LLM makes tool call.

Files changed:

- client/anthropic_bridge.py (+9 -9)
- client/interface.py (+71 -23)
- client/mcp_client.py (+0 -2)
- client/prompts.py (+32 -0)
- rss_client.py (+21 -4)
client/anthropic_bridge.py (CHANGED)

```diff
@@ -62,11 +62,13 @@ class LLMBridge(abc.ABC):
 
 
     @abc.abstractmethod
-    async def submit_query(self, query: List[Dict], formatted_tools: Any) -> Dict[str, Any]:
+    async def submit_query(self, system_prompt: str, query: List[Dict], formatted_tools: Any) -> Dict[str, Any]:
         '''Submit a query to the LLM with the formatted tools.
 
         Args:
-            query: User query string
+            system_prompt: System prompt for the LLM
+            query: User query messages, e.g.
+                [{'role': 'user', 'content': 'user query string'}]
             formatted_tools: Tools in the LLM-specific format
 
         Returns:
@@ -101,7 +103,7 @@ class LLMBridge(abc.ABC):
         return await self.mcp_client.invoke_tool(tool_name, kwargs)
 
 
-    async def process_query(self, query: List[Dict]) -> Dict[str, Any]:
+    async def process_query(self, system_prompt: str, query: List[Dict]) -> Dict[str, Any]:
         '''Process a user query through the LLM and execute any tool calls.
 
         This method handles the full flow:
@@ -125,7 +127,7 @@ class LLMBridge(abc.ABC):
         formatted_tools = await self.format_tools(self.tools)
 
         # 3. Submit query to LLM
-        llm_response = await self.submit_query(query, formatted_tools)
+        llm_response = await self.submit_query(system_prompt, query, formatted_tools)
 
         # 4. Parse tool calls from LLM response
         tool_call = await self.parse_tool_call(llm_response)
@@ -176,7 +178,8 @@ class AnthropicBridge(LLMBridge):
 
     async def submit_query(
         self,
-        query: List[Dict],
+        system_prompt: str,
+        query: List[Dict],
         formatted_tools: List[Dict[str, Any]]
     ) -> Dict[str, Any]:
         '''Submit a query to Anthropic with the formatted tools.
@@ -191,10 +194,7 @@ class AnthropicBridge(LLMBridge):
         response = self.llm_client.messages.create(
             model=self.model,
             max_tokens=4096,
-            system=
-            # messages=[
-            #     {'role': 'user', 'content': query}
-            # ],
+            system=system_prompt,
             messages=query,
             tools=formatted_tools
         )
```
client/interface.py (CHANGED)

```diff
@@ -2,44 +2,92 @@
 
 import json
 import logging
+from anthropic.types import text_block
 from gradio.components.chatbot import ChatMessage
+
+from client import prompts
 from client.anthropic_bridge import AnthropicBridge
 
-async def agent_input(bridge: AnthropicBridge, chat_history: list) -> list:
+async def agent_input(
+        bridge: AnthropicBridge,
+        chat_history: list
+) -> list:
+
     '''Handles model interactions.'''
 
     function_logger = logging.getLogger(__name__ + '.agent_input')
 
     input_messages = format_chat_history(chat_history)
-    result = await bridge.process_query(input_messages)
-    function_logger.info(result)
-            "content": result['llm_response'].content[0].text
-        })
-        function_logger.info('Model called the tool, but did not talk about it')
-        function_logger.info(tmp_input_messages)
-        result = await bridge.process_query(tmp_input_messages)
+    result = await bridge.process_query(prompts.DEFAULT_SYSTEM_PROMPT, input_messages)
+
+    if result['tool_result']:
+        tool_call = result['tool_call']
+        tool_name = tool_call['name']
+
+        if tool_name == 'rss_mcp_server_get_feed':
+
+            tool_parameters = tool_call['parameters']
+            website = tool_parameters['website']
+            user_query = input_messages[-1]['content']
+            response_content = result['llm_response'].content[0]
+
+            if isinstance(response_content, text_block.TextBlock):
+                intermediate_reply = response_content.text
+            else:
+                intermediate_reply = f'I will check the {website} RSS feed for you'
+
+            function_logger.info('User query: %s', user_query)
+            function_logger.info('Model intermediate reply: %s', intermediate_reply)
+            function_logger.info('LLM called %s on %s', tool_name, website)
+
+            articles = json.loads(result['tool_result'].content)['text']
+
+            prompt = prompts.GET_FEED_PROMPT.substitute(
+                website=website,
+                user_query=user_query,
+                intermediate_reply=intermediate_reply,
+                articles=articles
+            )
+
+            input_message = [{
+                'role': 'user',
+                'content': prompt
+            }]
+
+            function_logger.info('Re-prompting input %s', input_message)
+            result = await bridge.process_query(prompts.GET_FEED_SYSTEM_PROMPT, input_message)
+
+            try:
+                final_reply = result['llm_response'].content[0].text
+            except (IndexError, AttributeError):
+                final_reply = 'No final reply from model'
+
+            function_logger.info('LLM final reply: %s', final_reply)
+
+            chat_history.append({
+                "role": "assistant",
+                "content": intermediate_reply
+            })
+
+            chat_history.append({
+                "role": "assistant",
+                "content": final_reply
+            })
+
+    else:
+        try:
+            reply = result['llm_response'].content[0].text
+        except AttributeError:
+            reply = 'Bad reply - could not parse'
 
+        function_logger.info('Direct, no-tool reply: %s', reply)
 
         chat_history.append({
             "role": "assistant",
-            "content":
+            "content": reply
         })
 
     return chat_history
```
client/mcp_client.py (CHANGED)

```diff
@@ -83,8 +83,6 @@ class MCPClientWrapper:
         self.endpoint = endpoint
         self.timeout = timeout
         self.max_retries = max_retries
-        # self.tools = None
-        # self.anthropic = Anthropic()
 
 
     async def _execute_with_retry(self, operation_name: str, operation_func):
```
client/prompts.py (ADDED)

```diff
@@ -0,0 +1,32 @@
+'''Collection of prompts for Claude API calls in different
+conversational contexts.'''
+
+from string import Template
+
+DEFAULT_SYSTEM_PROMPT = 'You are a helpful tool-using assistant.'
+
+GET_FEED_SYSTEM_PROMPT = '''
+You are a helpful assistant. Your job is to facilitate interactions between
+Human users and LLM agents.
+'''
+
+GET_FEED_PROMPT = Template(
+'''
+Below is an exchange between a user and an agent. The user has asked
+the agent to get new content from the $website RSS feed. In order to
+complete the request, the agent has called a function which returned
+the RSS feed content from $website in JSON format. Your job is to
+complete the exchange by using the returned JSON RSS feed data to write
+a human readable reply to the user.
+
+user: $user_query
+
+agent: $intermediate_reply
+
+function call: get_feed_content($website)
+
+function return: $articles
+
+assistant:
+'''
+)
```
rss_client.py (CHANGED)

```diff
@@ -29,17 +29,28 @@ logging.basicConfig(
 
 logger = logging.getLogger(__name__)
 
+# Handle MCP server connection and interactions
 RSS_CLIENT = MCPClientWrapper(
     'https://agents-mcp-hackathon-rss-mcp-server.hf.space/gradio_api/mcp/sse'
 )
 
+# Handles Anthropic API I/O
 BRIDGE = AnthropicBridge(
     RSS_CLIENT,
     api_key=os.environ['ANTHROPIC_API_KEY']
 )
 
 async def send_message(message: str, chat_history: list) -> str:
-    '''Submits user message to agent'''
+    '''Submits user message to agent.
+
+    Args:
+        message: the new message from the user as a string
+        chat_history: list containing conversation history where each element is
+            a dictionary with keys 'role' and 'content'
+
+    Returns:
+        New chat history with model's response to user added.
+    '''
 
     function_logger = logging.getLogger(__name__ + '.submit_input')
     function_logger.info('Submitting user message: %s', message)
@@ -65,20 +76,26 @@ with gr.Blocks(title='MCP RSS client') as demo:
 
     chatbot = gr.Chatbot(
         value=[],
-        height=
+        height=200,
         type='messages',
         show_copy_button=True
    )
 
     msg = gr.Textbox(
+        'Are there any new posts on Hacker News?',
         label='Ask about content or articles on a site or platform',
         placeholder='Is there anything new on Hacker News?',
         scale=4
     )
 
     connect_btn.click(RSS_CLIENT.list_tools, outputs=status) # pylint: disable=no-member
 
+    msg.submit( # pylint: disable=no-member
+        send_message,
+        [msg, chatbot],
+        [msg, chatbot]
+    )
 
 if __name__ == '__main__':
 
-    demo.launch()
+    demo.launch(server_name="0.0.0.0", server_port=7860)
```