File size: 10,298 Bytes
7c27bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4facc97
 
 
 
 
 
 
 
 
 
 
 
7c27bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4facc97
7c27bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4facc97
7c27bf6
 
 
 
 
 
 
 
 
 
4facc97
7c27bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4facc97
 
 
7c27bf6
 
 
 
 
4facc97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
'''Classes for handling MCP server connection and operations.'''

import asyncio
import logging
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from dataclasses import dataclass

from mcp import ClientSession
from mcp.client.sse import sse_client


@dataclass()
class ToolParameter:
    '''A single input parameter accepted by a tool.

    Attributes:
        name: Name of the parameter.
        parameter_type: Type name of the parameter, such as 'string' or 'number'.
        description: Human-readable explanation of the parameter.
        required: True if the tool requires this parameter (False unless given).
        default: Default value declared for the parameter (None unless given).
    '''

    # Mandatory fields come first so positional construction keeps this order.
    name: str
    parameter_type: str
    description: str
    # Optional fields with defaults.
    required: bool = False
    default: Any = None


@dataclass
class ToolDef:
    '''Represents a tool definition.

    Attributes:
        name: Tool name
        description: Tool description
        parameters: List of ToolParameter objects
        metadata: Optional dictionary of additional metadata
        identifier: Tool identifier; when omitted or empty it is set to ``name``
    '''
    name: str
    description: str
    parameters: List[ToolParameter]
    metadata: Optional[Dict[str, Any]] = None
    identifier: str = ''

    def __post_init__(self) -> None:
        '''Fill ``identifier`` from ``name`` when no identifier was supplied.'''
        # A dataclass field default cannot refer to another field, so the
        # documented "identifier defaults to name" rule is applied here.
        if not self.identifier:
            self.identifier = self.name


@dataclass()
class ToolInvocationResult:
    '''Outcome of invoking a tool.

    Attributes:
        content: The tool's result rendered as a single string.
        error_code: 0 when the invocation succeeded, 1 when it reported an error.
    '''

    content: str  # textual result payload
    error_code: int  # 0 = success, 1 = error


class MCPConnectionError(Exception):
    '''Error signalling that an MCP connection or operation could not be completed.'''


class MCPTimeoutError(Exception):
    '''Error signalling that an MCP operation exceeded its allotted time.'''


class MCPClientWrapper:
    '''Main client wrapper class for interacting with Model Context Protocol (MCP) endpoints.

    Each public operation opens a fresh SSE connection and MCP session, performs
    a single request, and closes both again. Every attempt is bounded by
    ``timeout`` seconds and failed attempts are retried with exponential backoff.
    '''

    def __init__(self, endpoint: str, timeout: float = 30.0, max_retries: int = 3):
        '''Initialize MCP client with endpoint URL

        Args:
            endpoint: The MCP endpoint URL (expected to be http or https; the
                value is not validated here)
            timeout: Timeout in seconds applied to each individual attempt
            max_retries: Maximum number of attempts per operation; values
                below 1 still result in a single attempt
        '''

        self.endpoint = endpoint
        self.timeout = timeout
        self.max_retries = max_retries

        # Result of the most recent successful list_tools() call. Initialised
        # here so the attribute exists before list_tools() has ever run.
        self.tools: List[ToolDef] = []


    async def _execute_with_retry(self, operation_name: str, operation_func):
        '''Execute an operation with retry logic and proper error handling

        Args:
            operation_name: Name of the operation for logging
            operation_func: Zero-argument callable returning a fresh awaitable
                on each call (it is called once per attempt)

        Returns:
            Result of the operation

        Raises:
            MCPConnectionError: If the last attempt failed with a non-timeout
                error, or immediately on ValueError/TypeError (not retried)
            MCPTimeoutError: If the last attempt timed out
        '''

        logger = logging.getLogger(__name__ + '._execute_with_retry')

        # Always make at least one attempt; with max_retries <= 0 the loop
        # would otherwise never run and there would be no error to report.
        max_attempts = max(1, self.max_retries)
        attempts_made = 0
        last_exception: Optional[BaseException] = None

        for attempt in range(max_attempts):
            attempts_made = attempt + 1
            try:
                logger.debug(
                    'Attempting %s (attempt %s/%s)',
                    operation_name,
                    attempts_made,
                    max_attempts
                )

                # Execute with timeout
                result = await asyncio.wait_for(operation_func(), timeout=self.timeout)
                logger.debug('%s completed successfully', operation_name)
                return result

            except asyncio.TimeoutError as e:
                last_exception = MCPTimeoutError(
                    f'{operation_name} timed out after {self.timeout} seconds'
                )
                logger.warning('%s timed out on attempt %s: %s', operation_name, attempts_made, e)

            except Exception as e: # pylint: disable=broad-exception-caught
                last_exception = e
                logger.warning('%s failed on attempt %s: %s', operation_name, attempts_made, str(e))

                # Argument/type errors will not go away by retrying.
                if isinstance(e, (ValueError, TypeError)):
                    break

            # Wait before retry (exponential backoff: 1s, 2s, 4s, ...)
            if attempt < max_attempts - 1:
                wait_time = 2 ** attempt
                logger.debug('Waiting %s seconds before retry', wait_time)
                await asyncio.sleep(wait_time)

        # All attempts failed (or a non-retryable error stopped the loop early)
        if isinstance(last_exception, MCPTimeoutError):
            raise last_exception
        raise MCPConnectionError(
            f'{operation_name} failed after {attempts_made} attempts: {str(last_exception)}'
        ) from last_exception


    async def _safe_sse_operation(self, operation_func):
        '''Open an SSE connection and MCP session, run one operation, then close both

        Args:
            operation_func: Async callable taking the initialized ClientSession
                as its only argument

        Returns:
            Result of the operation

        Raises:
            Exception: Any error from connecting, initializing the session, or
                the operation itself is logged and re-raised unchanged
        '''

        logger = logging.getLogger(__name__ + '._safe_sse_operation')

        try:
            # `async with` guarantees both managers' exit handlers run, on
            # success and on error alike, so no manual cleanup is done here.
            # In particular, do NOT cancel asyncio.all_tasks(): that set spans
            # the whole event loop (including the current task), so cancelling
            # it would kill unrelated tasks and mask the real error with
            # CancelledError.
            async with sse_client(self.endpoint) as stream_context:
                async with ClientSession(*stream_context) as session:
                    await session.initialize()
                    return await operation_func(session)

        except Exception as e:
            logger.error('SSE operation failed: %s', str(e))
            raise


    def _to_tool_def(self, tool: Any) -> ToolDef:
        '''Convert one tool entry of a list_tools response into a ToolDef

        Args:
            tool: An entry of ``tools_result.tools``; its ``name``,
                ``description`` and ``inputSchema`` attributes are read

        Returns:
            ToolDef whose parameters come from the schema's ``properties``
        '''

        schema = tool.inputSchema or {}
        # A set gives O(1) membership checks; tolerate a missing/None list.
        required_params = set(schema.get('required') or [])

        parameters = []
        for param_name, param_schema in (schema.get('properties') or {}).items():
            # JSON Schema permits non-object property schemas (e.g. `true`);
            # treat those as carrying no type/description/default.
            if not isinstance(param_schema, dict):
                param_schema = {}
            parameters.append(
                ToolParameter(
                    name=param_name,
                    parameter_type=param_schema.get('type', 'string'),
                    description=param_schema.get('description', ''),
                    required=param_name in required_params,
                    default=param_schema.get('default'),
                )
            )

        return ToolDef(
            name=tool.name,
            description=tool.description or '',
            parameters=parameters,
            metadata={'endpoint': self.endpoint},
            identifier=tool.name  # Using name as identifier
        )


    async def list_tools(self) -> List[ToolDef]:
        '''List available tools from the MCP endpoint

        The returned list is also stored on ``self.tools``.

        Returns:
            List of ToolDef objects describing available tools

        Raises:
            MCPConnectionError: If connection fails
            MCPTimeoutError: If operation times out
        '''

        async def _operation(session):
            tools_result = await session.list_tools()
            tools = [self._to_tool_def(tool) for tool in tools_result.tools]
            self.tools = tools
            return tools

        async def _list_tools_operation():
            return await self._safe_sse_operation(_operation)

        return await self._execute_with_retry('list_tools', _list_tools_operation)


    async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
        '''Invoke a specific tool with parameters

        NOTE: the call goes through _execute_with_retry, so if an attempt times
        out or fails after the server already ran the tool, a non-idempotent
        tool may be executed more than once.

        Args:
            tool_name: Name of the tool to invoke
            kwargs: Dictionary of parameters to pass to the tool

        Returns:
            ToolInvocationResult whose content is the JSON dump of each result
            content item, one per line

        Raises:
            MCPConnectionError: If connection fails
            MCPTimeoutError: If operation times out
        '''

        async def _operation(session):
            call_result = await session.call_tool(tool_name, kwargs)
            return ToolInvocationResult(
                content='\n'.join(item.model_dump_json() for item in call_result.content),
                error_code=1 if call_result.isError else 0,
            )

        async def _invoke_tool_operation():
            return await self._safe_sse_operation(_operation)

        return await self._execute_with_retry(f'invoke_tool({tool_name})', _invoke_tool_operation)


    async def check_connection(self) -> bool:
        '''Check if the MCP endpoint is reachable

        Performs a full list_tools() call, including its retries and timeouts.

        Returns:
            True if connection is successful, False otherwise
        '''

        logger = logging.getLogger(__name__ + '._check_connection')

        try:
            await self.list_tools()
            return True
        except Exception as e: # pylint: disable=broad-exception-caught
            logger.debug('Connection check failed: %s', str(e))
            return False


    def get_endpoint_info(self) -> Dict[str, Any]:
        '''Get information about the configured endpoint

        Returns:
            Dictionary with endpoint information; 'port' is None when the URL
            has no port or its port is not a valid number
        '''
        parsed = urlparse(self.endpoint)
        try:
            port = parsed.port
        except ValueError:
            # urlparse raises ValueError for a non-numeric or out-of-range
            # port; report it as unknown instead of failing the whole call.
            port = None
        return {
            'endpoint': self.endpoint,
            'scheme': parsed.scheme,
            'hostname': parsed.hostname,
            'port': port,
            'path': parsed.path,
            'timeout': self.timeout,
            'max_retries': self.max_retries
        }