Spaces:
Running
Running
import { Client } from "@modelcontextprotocol/sdk/client/index.js"; | |
import { WebSocketClientTransport } from "@modelcontextprotocol/sdk/client/websocket.js"; | |
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; | |
import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; | |
import type { Tool } from "@modelcontextprotocol/sdk/types.js"; | |
import type { | |
MCPServerConfig, | |
MCPServerConnection, | |
MCPClientState, | |
MCPToolResult, | |
} from "../types/mcp.js"; | |
/**
 * Manages connections to Model Context Protocol (MCP) servers.
 *
 * The service keeps one {@link MCPServerConnection} record per configured
 * server, persists the server configurations in `localStorage` under the
 * `"mcp-servers"` key, re-checks connected servers every five minutes, and
 * notifies registered listeners whenever the connection state changes.
 */
export class MCPClientService {
  /** SDK clients keyed by server id; an entry is stored only after a successful connect. */
  private clients: Map<string, Client> = new Map();
  /** Connection records (config + runtime status) keyed by server id. */
  private connections: Map<string, MCPServerConnection> = new Map();
  /** Callbacks invoked with a fresh state snapshot on every state change. */
  private listeners: Array<(state: MCPClientState) => void> = [];
  /** Handle of the periodic health-check timer, or `null` when stopped. */
  private healthCheckInterval: ReturnType<typeof setInterval> | null = null;

  /** Loads persisted server configurations and starts the periodic health check. */
  constructor() {
    this.loadServerConfigs();
    this.startHealthCheck();
  }

  /** (Re)starts the health-check timer, which fires every 5 minutes. */
  private startHealthCheck(): void {
    if (this.healthCheckInterval) {
      clearInterval(this.healthCheckInterval);
    }
    this.healthCheckInterval = setInterval(() => {
      // performHealthCheck handles its own per-server errors; `void` marks
      // the promise as intentionally not awaited.
      void this.performHealthCheck();
    }, 5 * 60 * 1000);
  }

  /**
   * Probes every enabled, connected server by listing its tools. A server
   * whose probe fails is marked disconnected and one reconnect is attempted.
   * Listeners are notified once at the end.
   */
  private async performHealthCheck(): Promise<void> {
    console.log("Performing MCP health check...");
    for (const [serverId, connection] of this.connections) {
      if (!connection.config.enabled || !connection.isConnected) {
        continue;
      }
      try {
        const client = this.clients.get(serverId);
        if (client) {
          await client.listTools();
          console.log(`Health check passed for ${serverId}`);
        }
      } catch (error) {
        console.warn(`Health check failed for ${serverId}:`, error);
        connection.isConnected = false;
        connection.lastError = "Health check failed";
        try {
          // connectToServer closes the stale client before creating a new one.
          await this.connectToServer(serverId);
          console.log(`Reconnected to ${serverId}`);
        } catch (reconnectError) {
          console.error(`Failed to reconnect to ${serverId}:`, reconnectError);
        }
      }
    }
    this.notifyStateChange();
  }

  /** Stops the periodic health check. Open connections are left untouched. */
  public cleanup(): void {
    if (this.healthCheckInterval) {
      clearInterval(this.healthCheckInterval);
      this.healthCheckInterval = null;
    }
  }

  /** Registers a callback that receives the state after every change. */
  addStateListener(listener: (state: MCPClientState) => void): void {
    this.listeners.push(listener);
  }

  /** Unregisters a callback previously passed to {@link addStateListener}. */
  removeStateListener(listener: (state: MCPClientState) => void): void {
    const index = this.listeners.indexOf(listener);
    if (index > -1) {
      this.listeners.splice(index, 1);
    }
  }

  /**
   * Delivers the current state to every listener.
   *
   * Iterates over a copy so a listener may unsubscribe itself without causing
   * the next listener to be skipped, and isolates listener exceptions so one
   * faulty subscriber can neither starve the others nor make a successful
   * connect look like a failure.
   */
  private notifyStateChange(): void {
    const state = this.getState();
    for (const listener of [...this.listeners]) {
      try {
        listener(state);
      } catch (error) {
        console.error("MCP state listener threw:", error);
      }
    }
  }

  /** Returns the current state: all connection records keyed by server id. */
  getState(): MCPClientState {
    const servers: Record<string, MCPServerConnection> = {};
    for (const [id, connection] of this.connections) {
      servers[id] = connection;
    }
    return {
      servers,
      isLoading: false,
      error: undefined,
    };
  }

  /**
   * Restores server configurations from `localStorage`. Every restored server
   * starts disconnected. Failures (storage unavailable, malformed JSON, wrong
   * shape) are logged and leave the service with no restored servers.
   */
  private loadServerConfigs(): void {
    try {
      const stored = localStorage.getItem("mcp-servers");
      if (!stored) {
        return;
      }
      const parsed: unknown = JSON.parse(stored);
      if (!Array.isArray(parsed)) {
        throw new Error("stored MCP server list is not an array");
      }
      for (const config of parsed as MCPServerConfig[]) {
        this.connections.set(config.id, {
          config,
          isConnected: false,
          tools: [],
          lastError: undefined,
          lastConnected: undefined,
        });
      }
    } catch (error) {
      console.error("Failed to load MCP server configs:", error);
    }
  }

  /** Persists the configuration of every known server to `localStorage`. */
  private saveServerConfigs(): void {
    try {
      const configs = Array.from(this.connections.values()).map(
        (conn) => conn.config
      );
      localStorage.setItem("mcp-servers", JSON.stringify(configs));
    } catch (error) {
      console.error("Failed to save MCP server configs:", error);
    }
  }

  /**
   * Registers (or replaces) a server, persists it, and connects to it when
   * `config.enabled` is set.
   *
   * @throws Whatever {@link connectToServer} throws when auto-connect fails.
   */
  async addServer(config: MCPServerConfig): Promise<void> {
    // If this id is being replaced, detach its live client so it is not
    // orphaned behind the new (disconnected) record.
    const replaced = this.clients.get(config.id);
    this.clients.delete(config.id);

    const connection: MCPServerConnection = {
      config,
      isConnected: false,
      tools: [],
      lastError: undefined,
      lastConnected: undefined,
    };
    this.connections.set(config.id, connection);
    this.saveServerConfigs();
    this.notifyStateChange();

    if (replaced) {
      await this.closeQuietly(replaced, config.id);
    }
    if (config.enabled) {
      await this.connectToServer(config.id);
    }
  }

  /** Disconnects from a server, forgets it, and persists the change. */
  async removeServer(serverId: string): Promise<void> {
    await this.disconnectFromServer(serverId);
    this.connections.delete(serverId);
    this.clients.delete(serverId);
    this.saveServerConfigs();
    this.notifyStateChange();
  }

  /**
   * Connects to a registered server and loads its tool list. Does nothing
   * when the server is already marked connected.
   *
   * @throws Error when `serverId` is unknown, the URL or transport is
   *   invalid, or connecting / listing tools fails. On failure the record is
   *   marked disconnected with `lastError` set, and the half-open client is
   *   closed.
   */
  async connectToServer(serverId: string): Promise<void> {
    const connection = this.connections.get(serverId);
    if (!connection) {
      throw new Error(`Server ${serverId} not found`);
    }
    if (connection.isConnected) {
      return;
    }

    // A client can still be registered although the record says
    // "disconnected" (failed health check, onerror). Close it so the
    // reconnect does not leak its transport.
    const stale = this.clients.get(serverId);
    if (stale) {
      this.clients.delete(serverId);
      await this.closeQuietly(stale, serverId);
    }

    let client: Client | undefined;
    try {
      client = new Client(
        { name: "LFM2-WebGPU", version: "1.0.0" },
        { capabilities: { tools: {} } }
      );
      const transport = this.createTransport(connection.config);

      const activeClient = client;
      client.onerror = (error) => {
        // Ignore late errors from a client that has since been replaced.
        const registered = this.clients.get(serverId);
        if (registered && registered !== activeClient) {
          return;
        }
        console.error(`MCP Client error for ${serverId}:`, error);
        connection.lastError = error.message;
        connection.isConnected = false;
        this.notifyStateChange();
      };

      await client.connect(transport);
      const toolsResult = await client.listTools();

      connection.isConnected = true;
      connection.tools = toolsResult.tools;
      connection.lastError = undefined;
      connection.lastConnected = new Date();
      this.clients.set(serverId, client);
      this.notifyStateChange();
    } catch (error) {
      console.error(`Failed to connect to MCP server ${serverId}:`, error);
      connection.isConnected = false;
      connection.lastError =
        error instanceof Error ? error.message : "Connection failed";
      if (client) {
        await this.closeQuietly(client, serverId);
      }
      this.notifyStateChange();
      throw error;
    }
  }

  /**
   * Closes the client for a server (if any) and marks the record
   * disconnected with an empty tool list. Close errors are logged, not thrown.
   */
  async disconnectFromServer(serverId: string): Promise<void> {
    const client = this.clients.get(serverId);
    const connection = this.connections.get(serverId);
    if (client) {
      try {
        await client.close();
      } catch (error) {
        console.error(`Error disconnecting from ${serverId}:`, error);
      }
      this.clients.delete(serverId);
    }
    if (connection) {
      connection.isConnected = false;
      connection.tools = [];
      this.notifyStateChange();
    }
  }

  /** Returns the tools of every server that is both enabled and connected. */
  getAllTools(): Tool[] {
    const allTools: Tool[] = [];
    for (const connection of this.connections.values()) {
      if (connection.isConnected && connection.config.enabled) {
        allTools.push(...connection.tools);
      }
    }
    return allTools;
  }

  /**
   * Invokes a tool on a connected server.
   *
   * @returns The tool's content (an empty array when the server returned a
   *   non-array) and whether the server flagged the result as an error.
   * @throws Error when the server is not connected, or when the call fails.
   */
  async callTool(
    serverId: string,
    toolName: string,
    args: Record<string, unknown>
  ): Promise<MCPToolResult> {
    const client = this.clients.get(serverId);
    const connection = this.connections.get(serverId);
    if (!client || !connection?.isConnected) {
      throw new Error(`Not connected to server ${serverId}`);
    }
    try {
      const result = await client.callTool({
        name: toolName,
        arguments: args,
      });
      return {
        content: Array.isArray(result.content) ? result.content : [],
        isError: Boolean(result.isError),
      };
    } catch (error) {
      console.error(`Error calling tool ${toolName} on ${serverId}:`, error);
      throw error;
    }
  }

  /**
   * Checks whether a connection can be established with `config`, without
   * registering or persisting it. Uses the same transport and auth headers as
   * {@link connectToServer}, and always closes the probe client.
   *
   * @returns `true` when the connect handshake succeeds, `false` otherwise.
   */
  async testConnection(config: MCPServerConfig): Promise<boolean> {
    let client: Client | undefined;
    try {
      client = new Client(
        { name: "LFM2-WebGPU-Test", version: "1.0.0" },
        { capabilities: { tools: {} } }
      );
      await client.connect(this.createTransport(config));
      return true;
    } catch (error) {
      console.error("Test connection failed:", error);
      return false;
    } finally {
      if (client) {
        await this.closeQuietly(client, "connection test");
      }
    }
  }

  /** Connects to every enabled server that is not connected; failures are logged. */
  async connectAll(): Promise<void> {
    const promises = Array.from(this.connections.entries())
      .filter(
        ([, connection]) => connection.config.enabled && !connection.isConnected
      )
      .map(([serverId]) =>
        this.connectToServer(serverId).catch((error) => {
          console.error(`Failed to connect to ${serverId}:`, error);
        })
      );
    await Promise.all(promises);
  }

  /** Disconnects from every known server. */
  async disconnectAll(): Promise<void> {
    const promises = Array.from(this.connections.keys()).map((serverId) =>
      this.disconnectFromServer(serverId)
    );
    await Promise.all(promises);
  }

  /**
   * Builds the `Authorization` header for a server config. Returns an empty
   * object when no auth is configured or the credentials are incomplete.
   */
  private buildAuthHeaders(config: MCPServerConfig): Record<string, string> {
    const headers: Record<string, string> = {};
    const auth = config.auth;
    if (!auth) {
      return headers;
    }
    switch (auth.type) {
      case "bearer":
      case "oauth":
        if (auth.token) {
          headers["Authorization"] = `Bearer ${auth.token}`;
        }
        break;
      case "basic":
        if (auth.username && auth.password) {
          // btoa throws on characters outside Latin-1; callers invoke this
          // inside their try block so that surfaces as a connection error.
          const credentials = btoa(`${auth.username}:${auth.password}`);
          headers["Authorization"] = `Basic ${credentials}`;
        }
        break;
    }
    return headers;
  }

  /**
   * Creates the SDK transport selected by `config.transport`. Auth headers
   * are applied to the HTTP-based transports only; the WebSocket transport
   * is created without them.
   *
   * @throws TypeError when `config.url` is not a valid URL.
   * @throws Error when the transport kind is not supported.
   */
  private createTransport(
    config: MCPServerConfig
  ):
    | WebSocketClientTransport
    | StreamableHTTPClientTransport
    | SSEClientTransport {
    const url = new URL(config.url);
    const headers = this.buildAuthHeaders(config);
    const requestInit =
      Object.keys(headers).length > 0 ? { headers } : undefined;
    switch (config.transport) {
      case "websocket":
        return new WebSocketClientTransport(this.toWebSocketUrl(config.url));
      case "streamable-http":
        return new StreamableHTTPClientTransport(url, { requestInit });
      case "sse":
        return new SSEClientTransport(url, { requestInit });
      default:
        throw new Error(`Unsupported transport: ${config.transport}`);
    }
  }

  /**
   * Maps a server URL to its WebSocket equivalent: `https:` and `wss:` become
   * `wss:`, everything else becomes `ws:`. An already-secure `wss:` URL is
   * therefore never downgraded.
   */
  private toWebSocketUrl(rawUrl: string): URL {
    const wsUrl = new URL(rawUrl);
    const secure = wsUrl.protocol === "https:" || wsUrl.protocol === "wss:";
    wsUrl.protocol = secure ? "wss:" : "ws:";
    return wsUrl;
  }

  /** Closes a client on a best-effort basis; close errors are logged only. */
  private async closeQuietly(client: Client, context: string): Promise<void> {
    try {
      await client.close();
    } catch (error) {
      console.warn(`Error closing MCP client (${context}):`, error);
    }
  }
}