Commit
·
8381f2e
1
Parent(s):
8999b96
reverting jax cache changes
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import
|
2 |
|
3 |
# ---- Space mode gating (place above any JAX import!) ----
|
4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
@@ -28,33 +28,21 @@ else:
|
|
28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
CACHE_DIR = os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
35 |
-
|
36 |
-
# Prefer new API (JAX ≥ 0.4.26 / 0.5+), fall back to older initialize_cache
|
37 |
try:
|
38 |
-
|
39 |
-
if hasattr(cc, "set_cache_dir"):
|
40 |
-
cc.set_cache_dir(CACHE_DIR)
|
41 |
-
logging.info("JAX persistent cache (set_cache_dir) -> %s", CACHE_DIR)
|
42 |
-
else:
|
43 |
-
raise ImportError
|
44 |
except Exception:
|
45 |
-
|
46 |
-
from jax.experimental.compilation_cache import compilation_cache as cc_old # old-style
|
47 |
-
cc_old.initialize_cache(CACHE_DIR)
|
48 |
-
logging.info("JAX persistent cache (initialize_cache) -> %s", CACHE_DIR)
|
49 |
-
except Exception as e:
|
50 |
-
logging.warning("JAX persistent cache init skipped: %s", e)
|
51 |
|
52 |
-
#
|
53 |
try:
|
54 |
-
import
|
55 |
-
|
56 |
-
except Exception
|
57 |
-
|
|
|
58 |
|
59 |
|
60 |
|
@@ -79,6 +67,8 @@ from one_shot_generation import generate_loop_continuation_with_mrt, generate_st
|
|
79 |
|
80 |
import uuid, threading
|
81 |
|
|
|
|
|
82 |
import gradio as gr
|
83 |
from typing import Optional, Union, Literal
|
84 |
|
@@ -367,21 +357,15 @@ _WARMUP_LOCK = threading.Lock()
|
|
367 |
|
368 |
def _mrt_warmup():
|
369 |
"""
|
370 |
-
Build a minimal, bar-aligned silent context and run
|
371 |
-
|
372 |
"""
|
373 |
global _WARMED
|
374 |
with _WARMUP_LOCK:
|
375 |
if _WARMED:
|
376 |
return
|
377 |
try:
|
378 |
-
|
379 |
-
try:
|
380 |
-
import jax; _ = jax.devices()
|
381 |
-
except Exception:
|
382 |
-
pass
|
383 |
-
|
384 |
-
mrt = get_mrt() # will build model and (with our earlier changes) ensure assets if envs are set
|
385 |
|
386 |
# --- derive timing from model config ---
|
387 |
codec_fps = float(mrt.codec.frame_rate)
|
@@ -422,18 +406,10 @@ def _mrt_warmup():
|
|
422 |
state.context_tokens = context_tokens
|
423 |
style_vec = mrt.embed_style("warmup")
|
424 |
|
425 |
-
# ---
|
426 |
-
|
427 |
-
wav2, _ = mrt.generate_chunk(state=state, style=style_vec) # hit cached executables
|
428 |
-
|
429 |
-
# Optional sanity: ensure we didn't return all zeros
|
430 |
-
try:
|
431 |
-
if np.abs(wav2.samples).mean() <= 1e-7:
|
432 |
-
logging.warning("Warmup produced near-silence; continuing.")
|
433 |
-
except Exception:
|
434 |
-
pass
|
435 |
|
436 |
-
logging.info("MagentaRT warmup complete
|
437 |
finally:
|
438 |
try:
|
439 |
os.unlink(tmp_path)
|
|
|
1 |
+
import os
|
2 |
|
3 |
# ---- Space mode gating (place above any JAX import!) ----
|
4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
|
|
28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
30 |
|
31 |
+
import jax
|
32 |
+
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
33 |
+
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
|
|
|
|
|
|
34 |
try:
|
35 |
+
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
|
|
|
|
|
|
|
|
|
|
36 |
except Exception:
|
37 |
+
jax.config.update("jax_default_matmul_precision", "high") # older alias
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
# Initialize the on-disk compilation cache (best-effort)
|
40 |
try:
|
41 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
42 |
+
cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
|
43 |
+
except Exception:
|
44 |
+
pass
|
45 |
+
# --------------------------------------------------------------------
|
46 |
|
47 |
|
48 |
|
|
|
67 |
|
68 |
import uuid, threading
|
69 |
|
70 |
+
import logging
|
71 |
+
|
72 |
import gradio as gr
|
73 |
from typing import Optional, Union, Literal
|
74 |
|
|
|
357 |
|
358 |
def _mrt_warmup():
|
359 |
"""
|
360 |
+
Build a minimal, bar-aligned silent context and run one 2s generate_chunk
|
361 |
+
to trigger XLA JIT & autotune so first real request is fast.
|
362 |
"""
|
363 |
global _WARMED
|
364 |
with _WARMUP_LOCK:
|
365 |
if _WARMED:
|
366 |
return
|
367 |
try:
|
368 |
+
mrt = get_mrt()
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
|
370 |
# --- derive timing from model config ---
|
371 |
codec_fps = float(mrt.codec.frame_rate)
|
|
|
406 |
state.context_tokens = context_tokens
|
407 |
style_vec = mrt.embed_style("warmup")
|
408 |
|
409 |
+
# --- one throwaway chunk (~2s) ---
|
410 |
+
_wav, _state = mrt.generate_chunk(state=state, style=style_vec)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
|
412 |
+
logging.info("MagentaRT warmup complete.")
|
413 |
finally:
|
414 |
try:
|
415 |
os.unlink(tmp_path)
|