gperdrizet commited on
Commit
107cced
·
verified ·
1 Parent(s): a32b900

Finished MCP server connection logic, tested reading tools.

Browse files
Files changed (1) hide show
  1. rss_client.py +245 -0
rss_client.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from pathlib import Path
4
+ from logging.handlers import RotatingFileHandler
5
+ from typing import Any, Dict, List, Optional
6
+ from urllib.parse import urlparse
7
+ from dataclasses import dataclass
8
+
9
+ import gradio as gr
10
+ from mcp import ClientSession
11
+ from mcp.client.sse import sse_client
12
+ # from pydantic import BaseModel
13
+
14
+
15
+ # Make sure log directory exists
16
+ Path('logs').mkdir(parents=True, exist_ok=True)
17
+
18
+ # Set-up logger
19
+ logger = logging.getLogger()
20
+
21
+ logging.basicConfig(
22
+ handlers=[RotatingFileHandler(
23
+ 'logs/rss_server.log',
24
+ maxBytes=100000,
25
+ backupCount=10,
26
+ mode='w'
27
+ )],
28
+ level=logging.INFO,
29
+ format='%(levelname)s - %(name)s - %(message)s'
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ @dataclass
36
+ class ToolParameter:
37
+ """Represents a parameter for a tool.
38
+
39
+ Attributes:
40
+ name: Parameter name
41
+ parameter_type: Parameter type (e.g., "string", "number")
42
+ description: Parameter description
43
+ required: Whether the parameter is required
44
+ default: Default value for the parameter
45
+ """
46
+ name: str
47
+ parameter_type: str
48
+ description: str
49
+ required: bool = False
50
+ default: Any = None
51
+
52
+
53
+ @dataclass
54
+ class ToolDef:
55
+ """Represents a tool definition.
56
+
57
+ Attributes:
58
+ name: Tool name
59
+ description: Tool description
60
+ parameters: List of ToolParameter objects
61
+ metadata: Optional dictionary of additional metadata
62
+ identifier: Tool identifier (defaults to name)
63
+ """
64
+ name: str
65
+ description: str
66
+ parameters: List[ToolParameter]
67
+ metadata: Optional[Dict[str, Any]] = None
68
+ identifier: str = ""
69
+
70
+
71
+ class MCPConnectionError(Exception):
72
+ """Exception raised when MCP connection fails"""
73
+ pass
74
+
75
+
76
+ class MCPTimeoutError(Exception):
77
+ """Exception raised when MCP operation times out"""
78
+ pass
79
+
80
+
81
+ class MCPClientWrapper:
82
+ '''Client for interacting with Model Context Protocol (MCP) endpoints'''
83
+
84
+ def __init__(self, endpoint: str, timeout: float = 30.0, max_retries: int = 3):
85
+ """Initialize MCP client with endpoint URL
86
+
87
+ Args:
88
+ endpoint: The MCP endpoint URL (must be http or https)
89
+ timeout: Connection timeout in seconds
90
+ max_retries: Maximum number of retry attempts
91
+ """
92
+ if urlparse(endpoint).scheme not in ("http", "https"):
93
+ raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
94
+ self.endpoint = endpoint
95
+ self.timeout = timeout
96
+ self.max_retries = max_retries
97
+
98
+
99
+ async def _execute_with_retry(self, operation_name: str, operation_func):
100
+ """Execute an operation with retry logic and proper error handling
101
+
102
+ Args:
103
+ operation_name: Name of the operation for logging
104
+ operation_func: Async function to execute
105
+
106
+ Returns:
107
+ Result of the operation
108
+
109
+ Raises:
110
+ MCPConnectionError: If connection fails after all retries
111
+ MCPTimeoutError: If operation times out
112
+ """
113
+ last_exception = None
114
+
115
+ for attempt in range(self.max_retries):
116
+ try:
117
+ logger.debug(f"Attempting {operation_name} (attempt {attempt + 1}/{self.max_retries})")
118
+
119
+ # Execute with timeout
120
+ result = await asyncio.wait_for(operation_func(), timeout=self.timeout)
121
+ logger.debug(f"{operation_name} completed successfully")
122
+ return result
123
+
124
+ except asyncio.TimeoutError as e:
125
+ last_exception = MCPTimeoutError(f"{operation_name} timed out after {self.timeout} seconds")
126
+ logger.warning(f"{operation_name} timed out on attempt {attempt + 1}")
127
+
128
+ except Exception as e:
129
+ last_exception = e
130
+ logger.warning(f"{operation_name} failed on attempt {attempt + 1}: {str(e)}")
131
+
132
+ # Don't retry on certain types of errors
133
+ if isinstance(e, (ValueError, TypeError)):
134
+ break
135
+
136
+ # Wait before retry (exponential backoff)
137
+ if attempt < self.max_retries - 1:
138
+ wait_time = 2 ** attempt
139
+ logger.debug(f"Waiting {wait_time} seconds before retry")
140
+ await asyncio.sleep(wait_time)
141
+
142
+ # All retries failed
143
+ if isinstance(last_exception, MCPTimeoutError):
144
+ raise last_exception
145
+ else:
146
+ raise MCPConnectionError(f"{operation_name} failed after {self.max_retries} attempts: {str(last_exception)}")
147
+
148
+ async def _safe_sse_operation(self, operation_func):
149
+ """Safely execute an SSE operation with proper task cleanup
150
+
151
+ Args:
152
+ operation_func: Function that takes (streams, session) as arguments
153
+
154
+ Returns:
155
+ Result of the operation
156
+ """
157
+ streams = None
158
+ session = None
159
+
160
+ try:
161
+ # Create SSE client with proper error handling
162
+ streams = sse_client(self.endpoint)
163
+ async with streams as stream_context:
164
+ # Create session with proper cleanup
165
+ session = ClientSession(*stream_context)
166
+ async with session as session_context:
167
+ await session_context.initialize()
168
+ return await operation_func(session_context)
169
+
170
+ except Exception as e:
171
+ logger.error(f"SSE operation failed: {str(e)}")
172
+ # Ensure proper cleanup of any remaining tasks
173
+ if session:
174
+ try:
175
+ # Cancel any pending tasks in the session
176
+ tasks = [task for task in asyncio.all_tasks() if not task.done()]
177
+ if tasks:
178
+ logger.debug(f"Cancelling {len(tasks)} pending tasks")
179
+ for task in tasks:
180
+ task.cancel()
181
+ # Wait for tasks to be cancelled
182
+ await asyncio.gather(*tasks, return_exceptions=True)
183
+ except Exception as cleanup_error:
184
+ logger.warning(f"Error during task cleanup: {cleanup_error}")
185
+ raise
186
+
187
+ async def list_tools(self) -> List[ToolDef]:
188
+ """List available tools from the MCP endpoint
189
+
190
+ Returns:
191
+ List of ToolDef objects describing available tools
192
+
193
+ Raises:
194
+ MCPConnectionError: If connection fails
195
+ MCPTimeoutError: If operation times out
196
+ """
197
+ async def _list_tools_operation():
198
+ async def _operation(session):
199
+ tools_result = await session.list_tools()
200
+ tools = []
201
+
202
+ for tool in tools_result.tools:
203
+ parameters = []
204
+ required_params = tool.inputSchema.get("required", [])
205
+ for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
206
+ parameters.append(
207
+ ToolParameter(
208
+ name=param_name,
209
+ parameter_type=param_schema.get("type", "string"),
210
+ description=param_schema.get("description", ""),
211
+ required=param_name in required_params,
212
+ default=param_schema.get("default"),
213
+ )
214
+ )
215
+ tools.append(
216
+ ToolDef(
217
+ name=tool.name,
218
+ description=tool.description,
219
+ parameters=parameters,
220
+ metadata={"endpoint": self.endpoint},
221
+ identifier=tool.name # Using name as identifier
222
+ )
223
+ )
224
+ return tools
225
+
226
+ return await self._safe_sse_operation(_operation)
227
+
228
+ return await self._execute_with_retry("list_tools", _list_tools_operation)
229
+
230
+ client = MCPClientWrapper('https://agents-mcp-hackathon-rss-mcp-server.hf.space/gradio_api/mcp/sse')
231
+
232
+ # def gradio_interface():
233
+ with gr.Blocks(title="MCP RSS client") as demo:
234
+ gr.Markdown("# MCP RSS reader")
235
+ gr.Markdown("Connect to your MCP RSS server")
236
+
237
+ connect_btn = gr.Button('Connect')
238
+ status = gr.Textbox(label='Connection Status', interactive=False)
239
+
240
+ connect_btn.click(client.list_tools, outputs=status)
241
+
242
+
243
+ if __name__ == "__main__":
244
+
245
+ demo.launch(debug=True)