thecollabagepatch committed on
Commit
85e8363
·
1 Parent(s): 0577e3b

new SPACE_MODE env variable so the template can stay up

Browse files
Files changed (1) hide show
  1. app.py +85 -11
app.py CHANGED
@@ -1,13 +1,21 @@
1
  import os
2
- # Useful XLA GPU optimizations (harmless if a flag is unknown)
3
- os.environ.setdefault(
4
- "XLA_FLAGS",
5
- " ".join([
6
- "--xla_gpu_enable_triton_gemm=true",
7
- "--xla_gpu_enable_latency_hiding_scheduler=true",
8
- "--xla_gpu_autotune_level=2",
9
- ])
10
- )
 
 
 
 
 
 
 
 
11
 
12
  # Optional: persist JAX compile cache across restarts (reduces warmup time)
13
  os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
@@ -32,7 +40,7 @@ except Exception:
32
 
33
  from magenta_rt import system, audio as au
34
  import numpy as np
35
- from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
36
  import tempfile, io, base64, math, threading
37
  from fastapi.middleware.cors import CORSMiddleware
38
  from contextlib import contextmanager
@@ -76,6 +84,35 @@ from pydantic import BaseModel
76
 
77
  from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # ---- Finetune assets (mean & centroids) --------------------------------------
80
  # _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
81
  _ASSETS_REPO_ID: str | None = None
@@ -1108,7 +1145,44 @@ def jam_status(session_id: str):
1108
 
1109
  @app.get("/health")
1110
  def health():
1111
- return {"ok": True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1112
 
1113
  @app.middleware("http")
1114
  async def log_requests(request: Request, call_next):
 
1
import os

# ---- Space mode gating (must run before anything imports JAX!) ----
# SPACE_MODE: "serve" (default, GPU inference) or "template" (CPU-only shell
# that stays up so the Space can be duplicated).
SPACE_MODE = os.getenv("SPACE_MODE", "serve")

if SPACE_MODE == "serve":
    # Only set GPU-friendly XLA flags when we actually intend to serve on GPU.
    _xla_gpu_flags = (
        "--xla_gpu_enable_triton_gemm=true",
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_autotune_level=2",
    )
    os.environ.setdefault("XLA_FLAGS", " ".join(_xla_gpu_flags))
else:
    # Template mode: pin JAX to CPU so it won't try to load CUDA plugins.
    os.environ.setdefault("JAX_PLATFORMS", "cpu")

# Optional: persist JAX compile cache across restarts (reduces warmup time)
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
 
40
 
41
from magenta_rt import system, audio as au
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
# BUG FIX: JSONResponse is NOT exported from the top-level `fastapi` package;
# importing it there raises ImportError at startup. It lives in
# fastapi.responses (a Starlette re-export).
from fastapi.responses import JSONResponse
import tempfile, io, base64, math, threading
from fastapi.middleware.cors import CORSMiddleware
from contextlib import contextmanager
 
84
 
85
  from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
86
 
87
+ def _gpu_probe() -> dict:
88
+ """
89
+ Returns:
90
+ {
91
+ "ok": bool,
92
+ "backend": str | None, # "gpu" | "cpu" | "tpu" | None
93
+ "has_gpu": bool,
94
+ "devices": list[str], # e.g. ["gpu:0", "gpu:1"]
95
+ "error": str | None,
96
+ }
97
+ """
98
+ try:
99
+ import jax
100
+ try:
101
+ backend = jax.default_backend() # "gpu", "cpu", "tpu"
102
+ except Exception:
103
+ from jax.lib import xla_bridge
104
+ backend = getattr(xla_bridge.get_backend(), "platform", None)
105
+
106
+ try:
107
+ devices = jax.devices()
108
+ has_gpu = any(getattr(d, "platform", "") in ("gpu", "cuda", "rocm") for d in devices)
109
+ dev_list = [f"{getattr(d, 'platform', '?')}:{getattr(d, 'id', '?')}" for d in devices]
110
+ return {"ok": True, "backend": backend, "has_gpu": has_gpu, "devices": dev_list, "error": None}
111
+ except Exception as e:
112
+ return {"ok": False, "backend": backend, "has_gpu": False, "devices": [], "error": f"jax.devices failed: {e}"}
113
+ except Exception as e:
114
+ return {"ok": False, "backend": None, "has_gpu": False, "devices": [], "error": f"jax import failed: {e}"}
115
+
116
  # ---- Finetune assets (mean & centroids) --------------------------------------
117
  # _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
118
  _ASSETS_REPO_ID: str | None = None
 
1145
 
1146
@app.get("/health")
def health():
    """Readiness probe: 503 in template mode or when no GPU is visible, 200 otherwise."""
    # 1) Template mode → not ready (encourage duplication on GPU).
    if SPACE_MODE != "serve":
        payload = {
            "ok": False,
            "status": "template_mode",
            "message": "This Space is a GPU template. Duplicate it and select an L40s/A100-class runtime to use the API.",
            "mode": SPACE_MODE,
        }
        return JSONResponse(status_code=503, content=payload)

    # 2) Runtime hardware probe: JAX must be importable AND default to GPU.
    probe = _gpu_probe()
    gpu_ready = probe["ok"] and probe["has_gpu"] and probe.get("backend") == "gpu"
    if not gpu_ready:
        payload = {
            "ok": False,
            "status": "gpu_unavailable",
            "message": "GPU is not visible to JAX. Select a GPU runtime (e.g., L40s) to serve.",
            "probe": probe,
            "mode": SPACE_MODE,
        }
        return JSONResponse(status_code=503, content=payload)

    # 3) Ready; include operational hints for dashboards/clients.
    warmed = bool(_WARMED)
    with jam_lock:
        active_jams = sum(1 for w in jam_registry.values() if w.is_alive())
    return {
        "ok": True,
        "status": "ready" if warmed else "initializing",
        "mode": SPACE_MODE,
        "warmed": warmed,
        "active_jams": active_jams,
        "probe": probe,
    }
1186
 
1187
  @app.middleware("http")
1188
  async def log_requests(request: Request, call_next):