|
'''Classes for handling MCP server connection and operations.''' |
|
|
|
import asyncio |
|
import logging |
|
from typing import Any, Dict, List, Optional |
|
from urllib.parse import urlparse |
|
from dataclasses import dataclass |
|
|
|
from mcp import ClientSession |
|
from mcp.client.sse import sse_client |
|
|
|
|
|
@dataclass |
|
class ToolParameter: |
|
'''Represents a parameter for a tool. |
|
|
|
Attributes: |
|
name: Parameter name |
|
parameter_type: Parameter type (e.g., 'string', 'number') |
|
description: Parameter description |
|
required: Whether the parameter is required |
|
default: Default value for the parameter |
|
''' |
|
name: str |
|
parameter_type: str |
|
description: str |
|
required: bool = False |
|
default: Any = None |
|
|
|
|
|
@dataclass |
|
class ToolDef: |
|
'''Represents a tool definition. |
|
|
|
Attributes: |
|
name: Tool name |
|
description: Tool description |
|
parameters: List of ToolParameter objects |
|
metadata: Optional dictionary of additional metadata |
|
identifier: Tool identifier (defaults to name) |
|
''' |
|
name: str |
|
description: str |
|
parameters: List[ToolParameter] |
|
metadata: Optional[Dict[str, Any]] = None |
|
identifier: str = '' |
|
|
|
|
|
@dataclass |
|
class ToolInvocationResult: |
|
'''Represents the result of a tool invocation. |
|
|
|
Attributes: |
|
content: Result content as a string |
|
error_code: Error code (0 for success, 1 for error) |
|
''' |
|
content: str |
|
error_code: int |
|
|
|
|
|
class MCPConnectionError(Exception): |
|
'''Exception raised when MCP connection fails''' |
|
pass |
|
|
|
|
|
class MCPTimeoutError(Exception): |
|
'''Exception raised when MCP operation times out''' |
|
pass |
|
|
|
|
|
class MCPClientWrapper: |
|
'''Main client wrapper class for interacting with Model Context Protocol (MCP) endpoints''' |
|
|
|
def __init__(self, endpoint: str, timeout: float = 30.0, max_retries: int = 3): |
|
'''Initialize MCP client with endpoint URL |
|
|
|
Args: |
|
endpoint: The MCP endpoint URL (must be http or https) |
|
timeout: Connection timeout in seconds |
|
max_retries: Maximum number of retry attempts |
|
''' |
|
|
|
self.endpoint = endpoint |
|
self.timeout = timeout |
|
self.max_retries = max_retries |
|
|
|
|
|
async def _execute_with_retry(self, operation_name: str, operation_func): |
|
'''Execute an operation with retry logic and proper error handling |
|
|
|
Args: |
|
operation_name: Name of the operation for logging |
|
operation_func: Async function to execute |
|
|
|
Returns: |
|
Result of the operation |
|
|
|
Raises: |
|
MCPConnectionError: If connection fails after all retries |
|
MCPTimeoutError: If operation times out |
|
''' |
|
|
|
logger = logging.getLogger(__name__ + '_execute_with_retry') |
|
|
|
last_exception = None |
|
|
|
for attempt in range(self.max_retries): |
|
try: |
|
logger.debug( |
|
'Attempting %s (attempt %s/%s)', |
|
operation_name, |
|
attempt + 1, |
|
self.max_retries |
|
) |
|
|
|
|
|
result = await asyncio.wait_for(operation_func(), timeout=self.timeout) |
|
logger.debug('%s completed successfully', operation_name) |
|
return result |
|
|
|
except asyncio.TimeoutError as e: |
|
last_exception = MCPTimeoutError( |
|
f'{operation_name} timed out after {self.timeout} seconds' |
|
) |
|
logger.warning('%s timed out on attempt %s: %s', operation_name, attempt + 1, e) |
|
|
|
except Exception as e: |
|
last_exception = e |
|
logger.warning('%s failed on attempt %s: %s', operation_name, attempt + 1, str(e)) |
|
|
|
|
|
if isinstance(e, (ValueError, TypeError)): |
|
break |
|
|
|
|
|
if attempt < self.max_retries - 1: |
|
wait_time = 2 ** attempt |
|
logger.debug('Waiting %s seconds before retry', wait_time) |
|
await asyncio.sleep(wait_time) |
|
|
|
|
|
if isinstance(last_exception, MCPTimeoutError): |
|
raise last_exception |
|
else: |
|
raise MCPConnectionError( |
|
f'{operation_name} failed after {self.max_retries} attempts: {str(last_exception)}' |
|
) |
|
|
|
|
|
async def _safe_sse_operation(self, operation_func): |
|
'''Safely execute an SSE operation with proper task cleanup |
|
|
|
Args: |
|
operation_func: Function that takes (streams, session) as arguments |
|
|
|
Returns: |
|
Result of the operation |
|
''' |
|
|
|
logger = logging.getLogger(__name__ + '_safe_sse_operation') |
|
|
|
streams = None |
|
session = None |
|
|
|
try: |
|
|
|
streams = sse_client(self.endpoint) |
|
|
|
async with streams as stream_context: |
|
|
|
|
|
session = ClientSession(*stream_context) |
|
|
|
async with session as session_context: |
|
await session_context.initialize() |
|
return await operation_func(session_context) |
|
|
|
except Exception as e: |
|
logger.error('SSE operation failed: %s', str(e)) |
|
|
|
|
|
if session: |
|
try: |
|
|
|
tasks = [task for task in asyncio.all_tasks() if not task.done()] |
|
if tasks: |
|
logger.debug('Cancelling %s pending tasks', len(tasks)) |
|
for task in tasks: |
|
task.cancel() |
|
|
|
|
|
await asyncio.gather(*tasks, return_exceptions=True) |
|
|
|
except Exception as cleanup_error: |
|
logger.warning('Error during task cleanup: %s', cleanup_error) |
|
raise |
|
|
|
|
|
async def list_tools(self) -> List[ToolDef]: |
|
'''List available tools from the MCP endpoint |
|
|
|
Returns: |
|
List of ToolDef objects describing available tools |
|
|
|
Raises: |
|
MCPConnectionError: If connection fails |
|
MCPTimeoutError: If operation times out |
|
''' |
|
|
|
async def _list_tools_operation(): |
|
async def _operation(session): |
|
|
|
tools_result = await session.list_tools() |
|
tools = [] |
|
|
|
for tool in tools_result.tools: |
|
parameters = [] |
|
required_params = tool.inputSchema.get('required', []) |
|
|
|
for param_name, param_schema in tool.inputSchema.get('properties', {}).items(): |
|
parameters.append( |
|
ToolParameter( |
|
name=param_name, |
|
parameter_type=param_schema.get('type', 'string'), |
|
description=param_schema.get('description', ''), |
|
required=param_name in required_params, |
|
default=param_schema.get('default'), |
|
) |
|
) |
|
|
|
tools.append( |
|
ToolDef( |
|
name=tool.name, |
|
description=tool.description, |
|
parameters=parameters, |
|
metadata={'endpoint': self.endpoint}, |
|
identifier=tool.name |
|
) |
|
) |
|
|
|
self.tools = tools |
|
|
|
return tools |
|
|
|
return await self._safe_sse_operation(_operation) |
|
|
|
return await self._execute_with_retry('list_tools', _list_tools_operation) |
|
|
|
|
|
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: |
|
'''Invoke a specific tool with parameters |
|
|
|
Args: |
|
tool_name: Name of the tool to invoke |
|
kwargs: Dictionary of parameters to pass to the tool |
|
|
|
Returns: |
|
ToolInvocationResult containing the tool's response |
|
|
|
Raises: |
|
MCPConnectionError: If connection fails |
|
MCPTimeoutError: If operation times out |
|
''' |
|
|
|
async def _invoke_tool_operation(): |
|
async def _operation(session): |
|
result = await session.call_tool(tool_name, kwargs) |
|
return ToolInvocationResult( |
|
content='\n'.join([result.model_dump_json() for result in result.content]), |
|
error_code=1 if result.isError else 0, |
|
) |
|
|
|
return await self._safe_sse_operation(_operation) |
|
|
|
return await self._execute_with_retry(f'invoke_tool({tool_name})', _invoke_tool_operation) |
|
|
|
|
|
async def check_connection(self) -> bool: |
|
'''Check if the MCP endpoint is reachable |
|
|
|
Returns: |
|
True if connection is successful, False otherwise |
|
''' |
|
|
|
logger = logging.getLogger(__name__ + '_check_connection') |
|
|
|
try: |
|
await self.list_tools() |
|
return True |
|
except Exception as e: |
|
logger.debug('Connection check failed: %s', str(e)) |
|
return False |
|
|
|
|
|
def get_endpoint_info(self) -> Dict[str, Any]: |
|
'''Get information about the configured endpoint |
|
|
|
Returns: |
|
Dictionary with endpoint information |
|
''' |
|
parsed = urlparse(self.endpoint) |
|
return { |
|
'endpoint': self.endpoint, |
|
'scheme': parsed.scheme, |
|
'hostname': parsed.hostname, |
|
'port': parsed.port, |
|
'path': parsed.path, |
|
'timeout': self.timeout, |
|
'max_retries': self.max_retries |
|
} |
|
|