gperdrizet committed on
Commit 4facc97 · verified · 1 Parent(s): 55d469a

Progress on Anthropic agent.

.gitignore CHANGED
@@ -1,3 +1,4 @@
+__pycache__
 .vscode
 .venv
 logs
classes/__pycache__/client.cpython-310.pyc DELETED
Binary file (7.14 kB)
 
client/anthropic_bridge.py ADDED
@@ -0,0 +1,302 @@
+'''Classes to connect to Anthropic inference endpoint'''
+
+import abc
+from typing import Dict, List, Any, Optional
+import anthropic
+from client.mcp_client import MCPClientWrapper, ToolDef, ToolParameter, ToolInvocationResult
+
+DEFAULT_ANTHROPIC_MODEL = 'claude-3-haiku-20240307'
+
+# Type mapping from Python/MCP types to JSON Schema types
+TYPE_MAPPING = {
+    'int': 'integer',
+    'bool': 'boolean',
+    'str': 'string',
+    'float': 'number',
+    'list': 'array',
+    'dict': 'object',
+    'boolean': 'boolean',
+    'string': 'string',
+    'integer': 'integer',
+    'number': 'number',
+    'array': 'array',
+    'object': 'object'
+}
+
+
+class LLMBridge(abc.ABC):
+    '''Abstract base class for LLM bridge implementations.'''
+
+    def __init__(self, mcp_client: MCPClientWrapper):
+        '''Initialize the LLM bridge with an MCPClient instance.
+
+        Args:
+            mcp_client: An initialized MCPClient instance
+        '''
+        self.mcp_client = mcp_client
+        self.tools = None
+
+
+    async def fetch_tools(self) -> List[ToolDef]:
+        '''Fetch available tools from the MCP endpoint.
+
+
+        Returns:
+            List of ToolDef objects
+        '''
+        self.tools = await self.mcp_client.list_tools()
+        return self.tools
+
+
+    @abc.abstractmethod
+    async def format_tools(self, tools: List[ToolDef]) -> Any:
+        '''Format tools for the specific LLM provider.
+
+        Args:
+            tools: List of ToolDef objects
+
+        Returns:
+            Formatted tools in the LLM-specific format
+        '''
+        pass
+
+
+    @abc.abstractmethod
+    async def submit_query(self, query: str, formatted_tools: Any) -> Dict[str, Any]:
+        '''Submit a query to the LLM with the formatted tools.
+
+        Args:
+            query: User query string
+            formatted_tools: Tools in the LLM-specific format
+
+        Returns:
+            LLM response
+        '''
+        pass
+
+
+    @abc.abstractmethod
+    async def parse_tool_call(self, llm_response: Any) -> Optional[Dict[str, Any]]:
+        '''Parse the LLM response to extract tool calls.
+
+        Args:
+            llm_response: Response from the LLM
+
+        Returns:
+            Dictionary with tool name and parameters, or None if no tool call
+        '''
+        pass
+
+
+    async def execute_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+        '''Execute a tool with the given parameters.
+
+        Args:
+            tool_name: Name of the tool to invoke
+            kwargs: Dictionary of parameters to pass to the tool
+
+        Returns:
+            ToolInvocationResult containing the tool's response
+        '''
+        return await self.mcp_client.invoke_tool(tool_name, kwargs)
+
+
+    async def process_query(self, query: str) -> Dict[str, Any]:
+        '''Process a user query through the LLM and execute any tool calls.
+
+        This method handles the full flow:
+        1. Fetch tools if not already fetched
+        2. Format tools for the LLM
+        3. Submit query to LLM
+        4. Parse tool calls from LLM response
+        5. Execute tool if needed
+
+        Args:
+            query: User query string
+
+        Returns:
+            Dictionary containing the LLM response, tool call, and tool result
+        '''
+        # 1. Fetch tools if not already fetched
+        if self.tools is None:
+            await self.fetch_tools()
+
+        # 2. Format tools for the LLM
+        formatted_tools = await self.format_tools(self.tools)
+
+        # 3. Submit query to LLM
+        llm_response = await self.submit_query(query, formatted_tools)
+
+        # 4. Parse tool calls from LLM response
+        tool_call = await self.parse_tool_call(llm_response)
+
+        result = {
+            'llm_response': llm_response,
+            'tool_call': tool_call,
+            'tool_result': None
+        }
+
+        # 5. Execute tool if needed
+        if tool_call:
+            tool_name = tool_call.get('name')
+            kwargs = tool_call.get('parameters', {})
+            tool_result = await self.execute_tool(tool_name, kwargs)
+            result['tool_result'] = tool_result
+
+        return result
+
+
+class AnthropicBridge(LLMBridge):
+    '''Anthropic-specific implementation of the LLM Bridge.'''
+
+    def __init__(self, mcp_client, api_key, model=DEFAULT_ANTHROPIC_MODEL):
+        '''Initialize Anthropic bridge with API key and model.
+
+        Args:
+            mcp_client: An initialized MCPClient instance
+            api_key: Anthropic API key
+            model: Anthropic model to use (default: DEFAULT_ANTHROPIC_MODEL)
+        '''
+        super().__init__(mcp_client)
+        self.llm_client = anthropic.Anthropic(api_key=api_key)
+        self.model = model
+
+
+    async def format_tools(self, tools: List[ToolDef]) -> List[Dict[str, Any]]:
+        '''Format tools for Anthropic.
+
+        Args:
+            tools: List of ToolDef objects
+
+        Returns:
+            List of tools in Anthropic format
+        '''
+        return to_anthropic_format(tools)
+
+
+    async def submit_query(
+        self,
+        query: str,
+        formatted_tools: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        '''Submit a query to Anthropic with the formatted tools.
+
+        Args:
+            query: User query string
+            formatted_tools: Tools in Anthropic format
+
+        Returns:
+            Anthropic API response
+        '''
+        response = self.llm_client.messages.create(
+            model=self.model,
+            max_tokens=4096,
+            system='You are a helpful tool-using assistant.',
+            messages=[
+                {'role': 'user', 'content': query}
+            ],
+            tools=formatted_tools
+        )
+
+        return response
+
+
+    async def parse_tool_call(self, llm_response: Any) -> Optional[Dict[str, Any]]:
+        '''Parse the Anthropic response to extract tool calls.
+
+        Args:
+            llm_response: Response from Anthropic
+
+        Returns:
+            Dictionary with tool name and parameters, or None if no tool call
+        '''
+        for content in llm_response.content:
+            if content.type == 'tool_use':
+                return {
+                    'name': content.name,
+                    'parameters': content.input
+                }
+
+        return None
+
+
+
+def to_anthropic_format(tools: List[ToolDef]) -> List[Dict[str, Any]]:
+    '''Convert ToolDef objects to Anthropic tool format.
+
+    Args:
+        tools: List of ToolDef objects to convert
+
+    Returns:
+        List of dictionaries in Anthropic tool format
+    '''
+
+    anthropic_tools = []
+    for tool in tools:
+        anthropic_tool = {
+            'name': tool.name,
+            'description': tool.description,
+            'input_schema': {
+                'type': 'object',
+                'properties': {},
+                'required': []
+            }
+        }
+
+        # Add properties
+        for param in tool.parameters:
+            # Map the type or use the original if no mapping exists
+            schema_type = TYPE_MAPPING.get(param.parameter_type, param.parameter_type)
+
+            param_schema = {
+                'type': schema_type,  # Use mapped type
+                'description': param.description
+            }
+
+            # For arrays, we need to specify the items type
+            if schema_type == 'array':
+                item_type = _infer_array_item_type(param)
+                param_schema['items'] = {'type': item_type}
+
+            anthropic_tool['input_schema']['properties'][param.name] = param_schema
+
+            # Add default value if provided
+            if param.default is not None:
+                anthropic_tool['input_schema']['properties'][param.name]['default'] = param.default
+
+            # Add to required list if required
+            if param.required:
+                anthropic_tool['input_schema']['required'].append(param.name)
+
+        anthropic_tools.append(anthropic_tool)
+    return anthropic_tools
+
+
+def _infer_array_item_type(param: ToolParameter) -> str:
+    '''Infer the item type for an array parameter based on its name and description.
+
+    Args:
+        param: The ToolParameter object
+
+    Returns:
+        The inferred JSON Schema type for array items
+    '''
+    # Default to string items
+    item_type = 'string'
+
+    # Check if parameter name contains hints about item type
+    param_name_lower = param.name.lower()
+    if any(hint in param_name_lower for hint in ['language', 'code', 'tag', 'name', 'id']):
+        item_type = 'string'
+    elif any(hint in param_name_lower for hint in ['number', 'count', 'amount', 'index']):
+        item_type = 'integer'
+
+    # Also check the description for hints
+    if param.description:
+        desc_lower = param.description.lower()
+        if 'string' in desc_lower or 'text' in desc_lower or 'language' in desc_lower:
+            item_type = 'string'
+        elif 'number' in desc_lower or 'integer' in desc_lower or 'int' in desc_lower:
+            item_type = 'integer'
+
+    return item_type
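
For context, a minimal usage sketch of the new bridge (not part of the commit): it assumes the package layout above is importable, that ANTHROPIC_API_KEY is set, and uses a placeholder MCP endpoint URL and query.

# Hedged sketch: endpoint URL and query text are placeholders, not values from this commit.
import asyncio
import os

from client.mcp_client import MCPClientWrapper
from client.anthropic_bridge import AnthropicBridge

async def main() -> None:
    # Hypothetical SSE endpoint; any Gradio MCP server URL would do here
    client = MCPClientWrapper('https://example-mcp-server.hf.space/gradio_api/mcp/sse')
    bridge = AnthropicBridge(client, api_key=os.environ['ANTHROPIC_API_KEY'])

    # process_query fetches and formats tools, queries Claude, and runs any tool call
    result = await bridge.process_query('Summarize the latest items from an RSS feed')

    # result bundles the raw Anthropic response, the parsed tool call (if any), and the tool result
    print(result['tool_call'])
    print(result['tool_result'])

if __name__ == '__main__':
    asyncio.run(main())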
classes/client.py → client/mcp_client.py RENAMED
@@ -46,6 +46,18 @@ class ToolDef:
     identifier: str = ''
 
 
+@dataclass
+class ToolInvocationResult:
+    '''Represents the result of a tool invocation.
+
+    Attributes:
+        content: Result content as a string
+        error_code: Error code (0 for success, 1 for error)
+    '''
+    content: str
+    error_code: int
+
+
 class MCPConnectionError(Exception):
     '''Exception raised when MCP connection fails'''
     pass
@@ -68,11 +80,11 @@ class MCPClientWrapper:
             max_retries: Maximum number of retry attempts
         '''
 
-        if urlparse(endpoint).scheme not in ('http', 'https'):
-            raise ValueError(f'Endpoint {endpoint} is not a valid HTTP(S) URL')
         self.endpoint = endpoint
         self.timeout = timeout
         self.max_retries = max_retries
+        # self.tools = None
+        # self.anthropic = Anthropic()
 
 
     async def _execute_with_retry(self, operation_name: str, operation_func):
@@ -136,6 +148,7 @@ class MCPClientWrapper:
             f'{operation_name} failed after {self.max_retries} attempts: {str(last_exception)}'
         )
 
+
     async def _safe_sse_operation(self, operation_func):
         '''Safely execute an SSE operation with proper task cleanup
 
@@ -184,6 +197,7 @@ class MCPClientWrapper:
                 logger.warning('Error during task cleanup: %s', cleanup_error)
             raise
 
+
     async def list_tools(self) -> List[ToolDef]:
         '''List available tools from the MCP endpoint
 
@@ -194,6 +208,7 @@ class MCPClientWrapper:
            MCPConnectionError: If connection fails
            MCPTimeoutError: If operation times out
        '''
+
        async def _list_tools_operation():
            async def _operation(session):
 
@@ -224,8 +239,74 @@ class MCPClientWrapper:
                            identifier=tool.name # Using name as identifier
                        )
                    )
+
+                self.tools = tools
+
                return tools
 
            return await self._safe_sse_operation(_operation)
 
        return await self._execute_with_retry('list_tools', _list_tools_operation)
+
+
+    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
+        '''Invoke a specific tool with parameters
+
+        Args:
+            tool_name: Name of the tool to invoke
+            kwargs: Dictionary of parameters to pass to the tool
+
+        Returns:
+            ToolInvocationResult containing the tool's response
+
+        Raises:
+            MCPConnectionError: If connection fails
+            MCPTimeoutError: If operation times out
+        '''
+
+        async def _invoke_tool_operation():
+            async def _operation(session):
+                result = await session.call_tool(tool_name, kwargs)
+                return ToolInvocationResult(
+                    content='\n'.join([result.model_dump_json() for result in result.content]),
+                    error_code=1 if result.isError else 0,
+                )
+
+            return await self._safe_sse_operation(_operation)
+
+        return await self._execute_with_retry(f'invoke_tool({tool_name})', _invoke_tool_operation)
+
+
+    async def check_connection(self) -> bool:
+        '''Check if the MCP endpoint is reachable
+
+        Returns:
+            True if connection is successful, False otherwise
+        '''
+
+        logger = logging.getLogger(__name__ + '_check_connection')
+
+        try:
+            await self.list_tools()
+            return True
+        except Exception as e: # pylint: disable=broad-exception-caught
+            logger.debug('Connection check failed: %s', str(e))
+            return False
+
+
+    def get_endpoint_info(self) -> Dict[str, Any]:
+        '''Get information about the configured endpoint
+
+        Returns:
+            Dictionary with endpoint information
+        '''
+        parsed = urlparse(self.endpoint)
+        return {
+            'endpoint': self.endpoint,
+            'scheme': parsed.scheme,
+            'hostname': parsed.hostname,
+            'port': parsed.port,
+            'path': parsed.path,
+            'timeout': self.timeout,
+            'max_retries': self.max_retries
+        }
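
For reference, a hedged sketch of exercising the renamed MCPClientWrapper surface (not part of the commit): the endpoint URL, tool name, and tool arguments below are hypothetical placeholders, since real tool names come from whatever the server reports via list_tools().

# Hedged sketch: 'get_feed' and its arguments are made up for illustration only.
import asyncio

from client.mcp_client import MCPClientWrapper

async def main() -> None:
    # Placeholder endpoint; swap in a real Gradio MCP SSE URL
    client = MCPClientWrapper('https://example-mcp-server.hf.space/gradio_api/mcp/sse')

    if await client.check_connection():
        print(client.get_endpoint_info())

        tools = await client.list_tools()
        print([tool.name for tool in tools])

        # invoke_tool returns a ToolInvocationResult(content, error_code)
        result = await client.invoke_tool('get_feed', {'url': 'https://example.com/rss'})
        print(result.error_code, result.content)

if __name__ == '__main__':
    asyncio.run(main())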
rss_client.py CHANGED
@@ -1,11 +1,13 @@
 '''RSS MCP server demonstration client app.'''
 
+import os
 import logging
 from pathlib import Path
 from logging.handlers import RotatingFileHandler
 
 import gradio as gr
-from classes.client import MCPClientWrapper
+from client.mcp_client import MCPClientWrapper
+from client.anthropic_bridge import AnthropicBridge
 
 # Make sure log directory exists
 Path('logs').mkdir(parents=True, exist_ok=True)
@@ -15,7 +17,7 @@ logger = logging.getLogger()
 
 logging.basicConfig(
     handlers=[RotatingFileHandler(
-        'logs/rss_server.log',
+        'logs/rss_client.log',
         maxBytes=100000,
         backupCount=10,
         mode='w'
@@ -27,17 +29,50 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 client = MCPClientWrapper('https://agents-mcp-hackathon-rss-mcp-server.hf.space/gradio_api/mcp/sse')
+bridge = AnthropicBridge(
+    client,
+    api_key=os.environ['ANTHROPIC_API_KEY']
+)
+
+async def submit_input(message: str, chat_history: list) -> str:
+    '''Submits user message to agent'''
+
+    function_logger = logging.getLogger(__name__ + '.submit_input')
+
+    result = await bridge.process_query(message)
+    function_logger.info(result)
+    chat_history.append({"role": "user", "content": message})
+    chat_history.append({"role": "assistant", "content": result['llm_response'].content[0].text})
+
+    return '', chat_history
+
 
 with gr.Blocks(title='MCP RSS client') as demo:
     gr.Markdown('# MCP RSS reader')
     gr.Markdown(
-        'Connect to the MCP RSS server: https://huggingface.co/spaces/Agents-MCP-Hackathon/rss-mcp-server'
+        'Connect to the MCP RSS server: ' +
+        'https://huggingface.co/spaces/Agents-MCP-Hackathon/rss-mcp-server'
     )
 
     connect_btn = gr.Button('Connect')
-    status = gr.Textbox(label='Connection Status', interactive=False, lines=50)
-    connect_btn.click(client.list_tools, outputs=status) # pylint: disable=no-member
+    status = gr.Textbox(label='Connection Status', interactive=False, lines=10)
+
+    chatbot = gr.Chatbot(
+        value=[],
+        height=500,
+        type='messages',
+        show_copy_button=True,
+        avatar_images=('👤', '🤖')
+    )
 
+    msg = gr.Textbox(
+        label='Your Question',
+        placeholder='Ask about an RSS feed',
+        scale=4
+    )
+
+    connect_btn.click(client.list_tools, outputs=status) # pylint: disable=no-member
+    msg.submit(submit_input, [msg, chatbot], [msg, chatbot]) # pylint: disable=no-member
 
 if __name__ == '__main__':