gperdrizet commited on
Commit
6d8aa95
·
verified ·
1 Parent(s): dbed421

Cleaned up

Browse files
Files changed (1) hide show
  1. rss_client.py +63 -45
rss_client.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import asyncio
2
  import logging
3
  from pathlib import Path
@@ -9,7 +11,6 @@ from dataclasses import dataclass
9
  import gradio as gr
10
  from mcp import ClientSession
11
  from mcp.client.sse import sse_client
12
- # from pydantic import BaseModel
13
 
14
 
15
  # Make sure log directory exists
@@ -34,15 +35,15 @@ logger = logging.getLogger(__name__)
34
 
35
  @dataclass
36
  class ToolParameter:
37
- """Represents a parameter for a tool.
38
 
39
  Attributes:
40
  name: Parameter name
41
- parameter_type: Parameter type (e.g., "string", "number")
42
  description: Parameter description
43
  required: Whether the parameter is required
44
  default: Default value for the parameter
45
- """
46
  name: str
47
  parameter_type: str
48
  description: str
@@ -52,7 +53,7 @@ class ToolParameter:
52
 
53
  @dataclass
54
  class ToolDef:
55
- """Represents a tool definition.
56
 
57
  Attributes:
58
  name: Tool name
@@ -60,21 +61,21 @@ class ToolDef:
60
  parameters: List of ToolParameter objects
61
  metadata: Optional dictionary of additional metadata
62
  identifier: Tool identifier (defaults to name)
63
- """
64
  name: str
65
  description: str
66
  parameters: List[ToolParameter]
67
  metadata: Optional[Dict[str, Any]] = None
68
- identifier: str = ""
69
 
70
 
71
  class MCPConnectionError(Exception):
72
- """Exception raised when MCP connection fails"""
73
  pass
74
 
75
 
76
  class MCPTimeoutError(Exception):
77
- """Exception raised when MCP operation times out"""
78
  pass
79
 
80
 
@@ -82,22 +83,22 @@ class MCPClientWrapper:
82
  '''Client for interacting with Model Context Protocol (MCP) endpoints'''
83
 
84
  def __init__(self, endpoint: str, timeout: float = 30.0, max_retries: int = 3):
85
- """Initialize MCP client with endpoint URL
86
 
87
  Args:
88
  endpoint: The MCP endpoint URL (must be http or https)
89
  timeout: Connection timeout in seconds
90
  max_retries: Maximum number of retry attempts
91
- """
92
- if urlparse(endpoint).scheme not in ("http", "https"):
93
- raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
94
  self.endpoint = endpoint
95
  self.timeout = timeout
96
  self.max_retries = max_retries
97
 
98
 
99
  async def _execute_with_retry(self, operation_name: str, operation_func):
100
- """Execute an operation with retry logic and proper error handling
101
 
102
  Args:
103
  operation_name: Name of the operation for logging
@@ -109,25 +110,32 @@ class MCPClientWrapper:
109
  Raises:
110
  MCPConnectionError: If connection fails after all retries
111
  MCPTimeoutError: If operation times out
112
- """
113
  last_exception = None
114
 
115
  for attempt in range(self.max_retries):
116
  try:
117
- logger.debug(f"Attempting {operation_name} (attempt {attempt + 1}/{self.max_retries})")
 
 
 
 
 
118
 
119
  # Execute with timeout
120
  result = await asyncio.wait_for(operation_func(), timeout=self.timeout)
121
- logger.debug(f"{operation_name} completed successfully")
122
  return result
123
 
124
  except asyncio.TimeoutError as e:
125
- last_exception = MCPTimeoutError(f"{operation_name} timed out after {self.timeout} seconds")
126
- logger.warning(f"{operation_name} timed out on attempt {attempt + 1}")
 
 
127
 
128
- except Exception as e:
129
  last_exception = e
130
- logger.warning(f"{operation_name} failed on attempt {attempt + 1}: {str(e)}")
131
 
132
  # Don't retry on certain types of errors
133
  if isinstance(e, (ValueError, TypeError)):
@@ -136,56 +144,64 @@ class MCPClientWrapper:
136
  # Wait before retry (exponential backoff)
137
  if attempt < self.max_retries - 1:
138
  wait_time = 2 ** attempt
139
- logger.debug(f"Waiting {wait_time} seconds before retry")
140
  await asyncio.sleep(wait_time)
141
 
142
  # All retries failed
143
  if isinstance(last_exception, MCPTimeoutError):
144
  raise last_exception
145
  else:
146
- raise MCPConnectionError(f"{operation_name} failed after {self.max_retries} attempts: {str(last_exception)}")
 
 
147
 
148
  async def _safe_sse_operation(self, operation_func):
149
- """Safely execute an SSE operation with proper task cleanup
150
 
151
  Args:
152
  operation_func: Function that takes (streams, session) as arguments
153
 
154
  Returns:
155
  Result of the operation
156
- """
157
  streams = None
158
  session = None
159
 
160
  try:
161
  # Create SSE client with proper error handling
162
  streams = sse_client(self.endpoint)
 
163
  async with streams as stream_context:
 
164
  # Create session with proper cleanup
165
  session = ClientSession(*stream_context)
 
166
  async with session as session_context:
167
  await session_context.initialize()
168
  return await operation_func(session_context)
169
 
170
  except Exception as e:
171
- logger.error(f"SSE operation failed: {str(e)}")
 
172
  # Ensure proper cleanup of any remaining tasks
173
  if session:
174
  try:
175
  # Cancel any pending tasks in the session
176
  tasks = [task for task in asyncio.all_tasks() if not task.done()]
177
  if tasks:
178
- logger.debug(f"Cancelling {len(tasks)} pending tasks")
179
  for task in tasks:
180
  task.cancel()
 
181
  # Wait for tasks to be cancelled
182
  await asyncio.gather(*tasks, return_exceptions=True)
183
- except Exception as cleanup_error:
184
- logger.warning(f"Error during task cleanup: {cleanup_error}")
 
185
  raise
186
 
187
  async def list_tools(self) -> List[ToolDef]:
188
- """List available tools from the MCP endpoint
189
 
190
  Returns:
191
  List of ToolDef objects describing available tools
@@ -193,31 +209,34 @@ class MCPClientWrapper:
193
  Raises:
194
  MCPConnectionError: If connection fails
195
  MCPTimeoutError: If operation times out
196
- """
197
  async def _list_tools_operation():
198
  async def _operation(session):
 
199
  tools_result = await session.list_tools()
200
  tools = []
201
 
202
  for tool in tools_result.tools:
203
  parameters = []
204
- required_params = tool.inputSchema.get("required", [])
205
- for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
 
206
  parameters.append(
207
  ToolParameter(
208
  name=param_name,
209
- parameter_type=param_schema.get("type", "string"),
210
- description=param_schema.get("description", ""),
211
  required=param_name in required_params,
212
- default=param_schema.get("default"),
213
  )
214
  )
 
215
  tools.append(
216
  ToolDef(
217
  name=tool.name,
218
  description=tool.description,
219
  parameters=parameters,
220
- metadata={"endpoint": self.endpoint},
221
  identifier=tool.name # Using name as identifier
222
  )
223
  )
@@ -225,21 +244,20 @@ class MCPClientWrapper:
225
 
226
  return await self._safe_sse_operation(_operation)
227
 
228
- return await self._execute_with_retry("list_tools", _list_tools_operation)
229
 
230
  client = MCPClientWrapper('https://agents-mcp-hackathon-rss-mcp-server.hf.space/gradio_api/mcp/sse')
231
 
232
  # def gradio_interface():
233
- with gr.Blocks(title="MCP RSS client") as demo:
234
- gr.Markdown("# MCP RSS reader")
235
- gr.Markdown("Connect to your MCP RSS server")
236
 
237
  connect_btn = gr.Button('Connect')
238
- status = gr.Textbox(label='Connection Status', interactive=False)
239
-
240
- connect_btn.click(client.list_tools, outputs=status)
241
 
242
 
243
- if __name__ == "__main__":
244
 
245
  demo.launch(debug=True)
 
1
+ '''RSS MCP server demonstration client app.'''
2
+
3
  import asyncio
4
  import logging
5
  from pathlib import Path
 
11
  import gradio as gr
12
  from mcp import ClientSession
13
  from mcp.client.sse import sse_client
 
14
 
15
 
16
  # Make sure log directory exists
 
35
 
36
  @dataclass
37
  class ToolParameter:
38
+ '''Represents a parameter for a tool.
39
 
40
  Attributes:
41
  name: Parameter name
42
+ parameter_type: Parameter type (e.g., 'string', 'number')
43
  description: Parameter description
44
  required: Whether the parameter is required
45
  default: Default value for the parameter
46
+ '''
47
  name: str
48
  parameter_type: str
49
  description: str
 
53
 
54
  @dataclass
55
  class ToolDef:
56
+ '''Represents a tool definition.
57
 
58
  Attributes:
59
  name: Tool name
 
61
  parameters: List of ToolParameter objects
62
  metadata: Optional dictionary of additional metadata
63
  identifier: Tool identifier (defaults to name)
64
+ '''
65
  name: str
66
  description: str
67
  parameters: List[ToolParameter]
68
  metadata: Optional[Dict[str, Any]] = None
69
+ identifier: str = ''
70
 
71
 
72
  class MCPConnectionError(Exception):
73
+ '''Exception raised when MCP connection fails'''
74
  pass
75
 
76
 
77
  class MCPTimeoutError(Exception):
78
+ '''Exception raised when MCP operation times out'''
79
  pass
80
 
81
 
 
83
  '''Client for interacting with Model Context Protocol (MCP) endpoints'''
84
 
85
  def __init__(self, endpoint: str, timeout: float = 30.0, max_retries: int = 3):
86
+ '''Initialize MCP client with endpoint URL
87
 
88
  Args:
89
  endpoint: The MCP endpoint URL (must be http or https)
90
  timeout: Connection timeout in seconds
91
  max_retries: Maximum number of retry attempts
92
+ '''
93
+ if urlparse(endpoint).scheme not in ('http', 'https'):
94
+ raise ValueError(f'Endpoint {endpoint} is not a valid HTTP(S) URL')
95
  self.endpoint = endpoint
96
  self.timeout = timeout
97
  self.max_retries = max_retries
98
 
99
 
100
  async def _execute_with_retry(self, operation_name: str, operation_func):
101
+ '''Execute an operation with retry logic and proper error handling
102
 
103
  Args:
104
  operation_name: Name of the operation for logging
 
110
  Raises:
111
  MCPConnectionError: If connection fails after all retries
112
  MCPTimeoutError: If operation times out
113
+ '''
114
  last_exception = None
115
 
116
  for attempt in range(self.max_retries):
117
  try:
118
+ logger.debug(
119
+ 'Attempting %s (attempt %s/%s)',
120
+ operation_name,
121
+ attempt + 1,
122
+ self.max_retries
123
+ )
124
 
125
  # Execute with timeout
126
  result = await asyncio.wait_for(operation_func(), timeout=self.timeout)
127
+ logger.debug('%s completed successfully', operation_name)
128
  return result
129
 
130
  except asyncio.TimeoutError as e:
131
+ last_exception = MCPTimeoutError(
132
+ f'{operation_name} timed out after {self.timeout} seconds'
133
+ )
134
+ logger.warning('%s timed out on attempt %s: %s', operation_name, attempt + 1, e)
135
 
136
+ except Exception as e: # pylint: disable=broad-exception-caught
137
  last_exception = e
138
+ logger.warning('%s failed on attempt %s: %s', operation_name, attempt + 1, str(e))
139
 
140
  # Don't retry on certain types of errors
141
  if isinstance(e, (ValueError, TypeError)):
 
144
  # Wait before retry (exponential backoff)
145
  if attempt < self.max_retries - 1:
146
  wait_time = 2 ** attempt
147
+ logger.debug('Waiting %s seconds before retry', wait_time)
148
  await asyncio.sleep(wait_time)
149
 
150
  # All retries failed
151
  if isinstance(last_exception, MCPTimeoutError):
152
  raise last_exception
153
  else:
154
+ raise MCPConnectionError(
155
+ f'{operation_name} failed after {self.max_retries} attempts: {str(last_exception)}'
156
+ )
157
 
158
  async def _safe_sse_operation(self, operation_func):
159
+ '''Safely execute an SSE operation with proper task cleanup
160
 
161
  Args:
162
  operation_func: Function that takes (streams, session) as arguments
163
 
164
  Returns:
165
  Result of the operation
166
+ '''
167
  streams = None
168
  session = None
169
 
170
  try:
171
  # Create SSE client with proper error handling
172
  streams = sse_client(self.endpoint)
173
+
174
  async with streams as stream_context:
175
+
176
  # Create session with proper cleanup
177
  session = ClientSession(*stream_context)
178
+
179
  async with session as session_context:
180
  await session_context.initialize()
181
  return await operation_func(session_context)
182
 
183
  except Exception as e:
184
+ logger.error('SSE operation failed: %s', str(e))
185
+
186
  # Ensure proper cleanup of any remaining tasks
187
  if session:
188
  try:
189
  # Cancel any pending tasks in the session
190
  tasks = [task for task in asyncio.all_tasks() if not task.done()]
191
  if tasks:
192
+ logger.debug('Cancelling %s pending tasks', len(tasks))
193
  for task in tasks:
194
  task.cancel()
195
+
196
  # Wait for tasks to be cancelled
197
  await asyncio.gather(*tasks, return_exceptions=True)
198
+
199
+ except Exception as cleanup_error: # pylint: disable=broad-exception-caught
200
+ logger.warning('Error during task cleanup: %s', cleanup_error)
201
  raise
202
 
203
  async def list_tools(self) -> List[ToolDef]:
204
+ '''List available tools from the MCP endpoint
205
 
206
  Returns:
207
  List of ToolDef objects describing available tools
 
209
  Raises:
210
  MCPConnectionError: If connection fails
211
  MCPTimeoutError: If operation times out
212
+ '''
213
  async def _list_tools_operation():
214
  async def _operation(session):
215
+
216
  tools_result = await session.list_tools()
217
  tools = []
218
 
219
  for tool in tools_result.tools:
220
  parameters = []
221
+ required_params = tool.inputSchema.get('required', [])
222
+
223
+ for param_name, param_schema in tool.inputSchema.get('properties', {}).items():
224
  parameters.append(
225
  ToolParameter(
226
  name=param_name,
227
+ parameter_type=param_schema.get('type', 'string'),
228
+ description=param_schema.get('description', ''),
229
  required=param_name in required_params,
230
+ default=param_schema.get('default'),
231
  )
232
  )
233
+
234
  tools.append(
235
  ToolDef(
236
  name=tool.name,
237
  description=tool.description,
238
  parameters=parameters,
239
+ metadata={'endpoint': self.endpoint},
240
  identifier=tool.name # Using name as identifier
241
  )
242
  )
 
244
 
245
  return await self._safe_sse_operation(_operation)
246
 
247
+ return await self._execute_with_retry('list_tools', _list_tools_operation)
248
 
249
  client = MCPClientWrapper('https://agents-mcp-hackathon-rss-mcp-server.hf.space/gradio_api/mcp/sse')
250
 
251
  # def gradio_interface():
252
+ with gr.Blocks(title='MCP RSS client') as demo:
253
+ gr.Markdown('# MCP RSS reader')
254
+ gr.Markdown('Connect to the MCP RSS server')
255
 
256
  connect_btn = gr.Button('Connect')
257
+ status = gr.Textbox(label='Connection Status', interactive=False, lines=50)
258
+ connect_btn.click(client.list_tools, outputs=status) # pylint: disable=no-member
 
259
 
260
 
261
+ if __name__ == '__main__':
262
 
263
  demo.launch(debug=True)