Commit
·
85e8363
1
Parent(s):
0577e3b
new SPACE_MODE env variable so the template can stay up
Browse files
app.py
CHANGED
@@ -1,13 +1,21 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
13 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
@@ -32,7 +40,7 @@ except Exception:
|
|
32 |
|
33 |
from magenta_rt import system, audio as au
|
34 |
import numpy as np
|
35 |
-
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
|
36 |
import tempfile, io, base64, math, threading
|
37 |
from fastapi.middleware.cors import CORSMiddleware
|
38 |
from contextlib import contextmanager
|
@@ -76,6 +84,35 @@ from pydantic import BaseModel
|
|
76 |
|
77 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
80 |
# _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
81 |
_ASSETS_REPO_ID: str | None = None
|
@@ -1108,7 +1145,44 @@ def jam_status(session_id: str):
|
|
1108 |
|
1109 |
@app.get("/health")
|
1110 |
def health():
|
1111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1112 |
|
1113 |
@app.middleware("http")
|
1114 |
async def log_requests(request: Request, call_next):
|
|
|
import os

# ---- Space mode gating (place above any JAX import!) ----
# SPACE_MODE decides whether this process actually serves ("serve") or runs
# as a read-only template ("template" / anything else).
SPACE_MODE = os.getenv("SPACE_MODE", "serve")  # "serve" | "template"

if SPACE_MODE == "serve":
    # Only set GPU-friendly XLA flags when we actually intend to serve on GPU
    _xla_flags = (
        "--xla_gpu_enable_triton_gemm=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_autotune_level=2"
    )
    os.environ.setdefault("XLA_FLAGS", _xla_flags)
else:
    # In template mode, force JAX to CPU so it won't try to load CUDA plugins
    os.environ.setdefault("JAX_PLATFORMS", "cpu")

# Optional: persist JAX compile cache across restarts (reduces warmup time)
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
|
|
from magenta_rt import system, audio as au
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
# BUG FIX: JSONResponse is not re-exported from the top-level `fastapi`
# package, so `from fastapi import JSONResponse` raises ImportError at
# startup. It lives in fastapi.responses (re-exported from Starlette).
from fastapi.responses import JSONResponse
import tempfile, io, base64, math, threading
from fastapi.middleware.cors import CORSMiddleware
from contextlib import contextmanager
|
|
|
84 |
|
85 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
86 |
|
87 |
+
def _gpu_probe() -> dict:
|
88 |
+
"""
|
89 |
+
Returns:
|
90 |
+
{
|
91 |
+
"ok": bool,
|
92 |
+
"backend": str | None, # "gpu" | "cpu" | "tpu" | None
|
93 |
+
"has_gpu": bool,
|
94 |
+
"devices": list[str], # e.g. ["gpu:0", "gpu:1"]
|
95 |
+
"error": str | None,
|
96 |
+
}
|
97 |
+
"""
|
98 |
+
try:
|
99 |
+
import jax
|
100 |
+
try:
|
101 |
+
backend = jax.default_backend() # "gpu", "cpu", "tpu"
|
102 |
+
except Exception:
|
103 |
+
from jax.lib import xla_bridge
|
104 |
+
backend = getattr(xla_bridge.get_backend(), "platform", None)
|
105 |
+
|
106 |
+
try:
|
107 |
+
devices = jax.devices()
|
108 |
+
has_gpu = any(getattr(d, "platform", "") in ("gpu", "cuda", "rocm") for d in devices)
|
109 |
+
dev_list = [f"{getattr(d, 'platform', '?')}:{getattr(d, 'id', '?')}" for d in devices]
|
110 |
+
return {"ok": True, "backend": backend, "has_gpu": has_gpu, "devices": dev_list, "error": None}
|
111 |
+
except Exception as e:
|
112 |
+
return {"ok": False, "backend": backend, "has_gpu": False, "devices": [], "error": f"jax.devices failed: {e}"}
|
113 |
+
except Exception as e:
|
114 |
+
return {"ok": False, "backend": None, "has_gpu": False, "devices": [], "error": f"jax import failed: {e}"}
|
115 |
+
|
116 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
117 |
# _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
118 |
_ASSETS_REPO_ID: str | None = None
|
|
|
1145 |
|
1146 |
@app.get("/health")
def health():
    """Readiness probe for the Space.

    Returns 503 while in template mode or when no GPU is visible to JAX,
    and 200 with operational hints once the runtime is GPU-backed.
    """
    # 1) Template mode → not ready (encourage duplication on GPU)
    if SPACE_MODE != "serve":
        return JSONResponse(
            status_code=503,
            content={
                "ok": False,
                "status": "template_mode",
                "message": "This Space is a GPU template. Duplicate it and select an L40s/A100-class runtime to use the API.",
                "mode": SPACE_MODE,
            },
        )

    # 2) Runtime hardware probe
    probe = _gpu_probe()
    # FIX: accept the same accelerator platform set that _gpu_probe uses for
    # its device check — some JAX builds report the backend as "cuda"/"rocm"
    # rather than "gpu", and those are healthy GPU runtimes too. Comparing
    # only against "gpu" would 503 a working CUDA/ROCm box.
    if not probe["ok"] or not probe["has_gpu"] or probe.get("backend") not in ("gpu", "cuda", "rocm"):
        return JSONResponse(
            status_code=503,
            content={
                "ok": False,
                "status": "gpu_unavailable",
                "message": "GPU is not visible to JAX. Select a GPU runtime (e.g., L40s) to serve.",
                "probe": probe,
                "mode": SPACE_MODE,
            },
        )

    # 3) Ready; include operational hints
    warmed = bool(_WARMED)
    with jam_lock:
        # Count only jam workers whose threads are still running.
        active_jams = sum(1 for w in jam_registry.values() if w.is_alive())
    return {
        "ok": True,
        "status": "ready" if warmed else "initializing",
        "mode": SPACE_MODE,
        "warmed": warmed,
        "active_jams": active_jams,
        "probe": probe,
    }
|
1186 |
|
1187 |
@app.middleware("http")
|
1188 |
async def log_requests(request: Request, call_next):
|