brentyi commited on
Commit
932e635
·
1 Parent(s): cdf42c7
Files changed (2) hide show
  1. app.py +12 -5
  2. viser_proxy_manager.py +12 -3
app.py CHANGED
@@ -8,8 +8,7 @@ from viser_proxy_manager import ViserProxyManager
8
 
9
 
10
  def main() -> None:
11
- app = fastapi.FastAPI()
12
- viser_manager = ViserProxyManager(app)
13
 
14
  # Create a Gradio interface with title, iframe, and buttons
15
  with gr.Blocks(title="Viser Viewer") as demo:
@@ -28,9 +27,16 @@ def main() -> None:
28
  # Use the request's base URL if available
29
  host = request.headers["host"]
30
 
 
 
 
 
 
 
 
31
  return f"""
32
  <div style="border: 2px solid #ccc; padding: 10px;">
33
- <iframe src="http://{host}/viser/{request.session_hash}/" width="100%" height="500px" frameborder="0"></iframe>
34
  </div>
35
  """
36
 
@@ -56,8 +62,9 @@ def main() -> None:
56
  assert request.session_hash is not None
57
  viser_manager.stop_server(request.session_hash)
58
 
59
- app = gr.mount_gradio_app(app, demo, path="/")
60
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
61
 
62
 
63
  if __name__ == "__main__":
 
8
 
9
 
10
  def main() -> None:
11
+ viser_manager = ViserProxyManager()
 
12
 
13
  # Create a Gradio interface with title, iframe, and buttons
14
  with gr.Blocks(title="Viser Viewer") as demo:
 
27
  # Use the request's base URL if available
28
  host = request.headers["host"]
29
 
30
+ # Determine protocol (use HTTPS for HuggingFace Spaces or other secure environments)
31
+ protocol = (
32
+ "https"
33
+ if request.headers.get("x-forwarded-proto") == "https"
34
+ else "http"
35
+ )
36
+
37
  return f"""
38
  <div style="border: 2px solid #ccc; padding: 10px;">
39
+ <iframe src="{protocol}://{host}/viser/{request.session_hash}/" width="100%" height="500px" frameborder="0"></iframe>
40
  </div>
41
  """
42
 
 
62
  assert request.session_hash is not None
63
  viser_manager.stop_server(request.session_hash)
64
 
65
+ demo.launch(prevent_thread_lock=True)
66
+ viser_manager.setup(demo.app)
67
+ demo.block_thread()
68
 
69
 
70
  if __name__ == "__main__":
viser_proxy_manager.py CHANGED
@@ -14,7 +14,6 @@ class ViserProxyManager:
14
  as well as proxying HTTP and WebSocket requests to the appropriate Viser server.
15
 
16
  Args:
17
- app: The FastAPI application to which the proxy routes will be added.
18
  min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
19
  These ports are used only for internal communication and don't need to be publicly exposed.
20
  max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
@@ -23,14 +22,24 @@ class ViserProxyManager:
23
 
24
  def __init__(
25
  self,
26
- app: FastAPI,
27
  min_local_port: int = 8000,
28
  max_local_port: int = 9000,
29
  ) -> None:
30
  self._min_port = min_local_port
31
  self._max_port = max_local_port
 
 
 
 
 
 
 
 
 
 
 
32
  self._server_from_session_hash: dict[str, viser.ViserServer] = {}
33
- self._last_port = min_local_port - 1 # Track last port tried
34
 
35
  @app.get("/viser/{server_id}/{proxy_path:path}")
36
  async def proxy(request: Request, server_id: str, proxy_path: str):
 
14
  as well as proxying HTTP and WebSocket requests to the appropriate Viser server.
15
 
16
  Args:
 
17
  min_local_port: Minimum local port number to use for Viser servers. Defaults to 8000.
18
  These ports are used only for internal communication and don't need to be publicly exposed.
19
  max_local_port: Maximum local port number to use for Viser servers. Defaults to 9000.
 
22
 
23
  def __init__(
24
  self,
 
25
  min_local_port: int = 8000,
26
  max_local_port: int = 9000,
27
  ) -> None:
28
  self._min_port = min_local_port
29
  self._max_port = max_local_port
30
+
31
+ def setup(
32
+ self,
33
+ app: FastAPI,
34
+ ) -> None:
35
+ """Set up the Viser proxy manager with the given FastAPI application.
36
+ This should be called after `demo.launch(prevent_thread_lock=True)`.
37
+
38
+ Args:
39
+ app: The FastAPI application to which the proxy routes will be added.
40
+ """
41
  self._server_from_session_hash: dict[str, viser.ViserServer] = {}
42
+ self._last_port = self._min_port - 1 # Track last port tried
43
 
44
  @app.get("/viser/{server_id}/{proxy_path:path}")
45
  async def proxy(request: Request, server_id: str, proxy_path: str):