brentyi commited on
Commit
ae1e5e7
·
1 Parent(s): 108d4e0

Add subprotocols support for websocket proxy

Browse files
Files changed (1) hide show
  1. viser_proxy_manager.py +34 -8
viser_proxy_manager.py CHANGED
@@ -83,7 +83,13 @@ class ViserProxyManager:
83
  @app.websocket("/viser/{server_id}")
84
  async def websocket_proxy(websocket: WebSocket, server_id: str):
85
  """Proxy WebSocket connections to the appropriate Viser server."""
86
- await websocket.accept()
 
 
 
 
 
 
87
 
88
  server = self._server_from_session_hash.get(server_id)
89
  if server is None:
@@ -92,15 +98,35 @@ class ViserProxyManager:
92
 
93
  # Determine target WebSocket URL
94
  target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
95
-
96
- if not target_ws_url:
97
- await websocket.close(code=1008, reason="Not Found")
98
- return
99
-
100
  try:
101
- # Connect to the target WebSocket
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  async with websockets.connect(
103
- target_ws_url, max_size=self._max_websocket_message_size_bytes
 
 
104
  ) as ws_target:
105
  # Create tasks for bidirectional communication
106
  async def forward_to_target():
 
83
  @app.websocket("/viser/{server_id}")
84
  async def websocket_proxy(websocket: WebSocket, server_id: str):
85
  """Proxy WebSocket connections to the appropriate Viser server."""
86
+ # Parse client's requested subprotocols
87
+ client_subprotocols = websocket.headers.get(
88
+ "sec-websocket-protocol", ""
89
+ ).split(",")
90
+ client_subprotocols = [
91
+ p.strip() for p in client_subprotocols if len(p.strip()) > 0
92
+ ]
93
 
94
  server = self._server_from_session_hash.get(server_id)
95
  if server is None:
 
98
 
99
  # Determine target WebSocket URL
100
  target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
 
 
 
 
 
101
  try:
102
+ # First connect to the target server to determine which subprotocol it selects
103
+ selected_protocol = None
104
+
105
+ # Only attempt subprotocol negotiation if client requested any
106
+ if client_subprotocols:
107
+ try:
108
+ async with websockets.connect(
109
+ target_ws_url,
110
+ subprotocols=client_subprotocols, # type: ignore
111
+ max_size=self._max_websocket_message_size_bytes,
112
+ ) as ws:
113
+ # Get the selected protocol from the server
114
+ selected_protocol = ws.subprotocol
115
+ except Exception:
116
+ # If connection fails, we'll try again without subprotocol negotiation
117
+ pass
118
+
119
+ # Now accept the client connection with the protocol selected by the target server
120
+ if selected_protocol is not None:
121
+ await websocket.accept(subprotocol=selected_protocol)
122
+ else:
123
+ await websocket.accept()
124
+
125
+ # Establish the main connection to the target server
126
  async with websockets.connect(
127
+ target_ws_url,
128
+ max_size=self._max_websocket_message_size_bytes,
129
+ subprotocols=client_subprotocols if client_subprotocols else None, # type: ignore
130
  ) as ws_target:
131
  # Create tasks for bidirectional communication
132
  async def forward_to_target():