Commit
·
4843704
1
Parent(s):
e140f31
jax cache for faster compilation attempt
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os
|
2 |
|
3 |
# ---- Space mode gating (place above any JAX import!) ----
|
4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
@@ -28,21 +28,33 @@ else:
|
|
28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
34 |
try:
|
35 |
-
jax.
|
|
|
|
|
|
|
|
|
|
|
36 |
except Exception:
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
#
|
40 |
try:
|
41 |
-
|
42 |
-
|
43 |
-
except Exception:
|
44 |
-
|
45 |
-
# --------------------------------------------------------------------
|
46 |
|
47 |
|
48 |
|
@@ -67,8 +79,6 @@ from one_shot_generation import generate_loop_continuation_with_mrt, generate_st
|
|
67 |
|
68 |
import uuid, threading
|
69 |
|
70 |
-
import logging
|
71 |
-
|
72 |
import gradio as gr
|
73 |
from typing import Optional, Union, Literal
|
74 |
|
@@ -357,15 +367,21 @@ _WARMUP_LOCK = threading.Lock()
|
|
357 |
|
358 |
def _mrt_warmup():
|
359 |
"""
|
360 |
-
Build a minimal, bar-aligned silent context and run
|
361 |
-
to trigger
|
362 |
"""
|
363 |
global _WARMED
|
364 |
with _WARMUP_LOCK:
|
365 |
if _WARMED:
|
366 |
return
|
367 |
try:
|
368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
|
370 |
# --- derive timing from model config ---
|
371 |
codec_fps = float(mrt.codec.frame_rate)
|
@@ -406,10 +422,18 @@ def _mrt_warmup():
|
|
406 |
state.context_tokens = context_tokens
|
407 |
style_vec = mrt.embed_style("warmup")
|
408 |
|
409 |
-
# ---
|
410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
|
412 |
-
logging.info("MagentaRT warmup complete.")
|
413 |
finally:
|
414 |
try:
|
415 |
os.unlink(tmp_path)
|
|
|
1 |
+
import logging, os
|
2 |
|
3 |
# ---- Space mode gating (place above any JAX import!) ----
|
4 |
SPACE_MODE = os.getenv("SPACE_MODE")
|
|
|
28 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
29 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
30 |
|
31 |
+
# --- JAX persistent compilation cache (new + old APIs), plus extra XLA caches ---
|
32 |
+
|
33 |
+
|
34 |
+
CACHE_DIR = os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
35 |
+
|
36 |
+
# Prefer new API (JAX ≥ 0.4.26 / 0.5+), fall back to older initialize_cache
|
37 |
try:
|
38 |
+
from jax.experimental import compilation_cache as cc # new-style
|
39 |
+
if hasattr(cc, "set_cache_dir"):
|
40 |
+
cc.set_cache_dir(CACHE_DIR)
|
41 |
+
logging.info("JAX persistent cache (set_cache_dir) -> %s", CACHE_DIR)
|
42 |
+
else:
|
43 |
+
raise ImportError
|
44 |
except Exception:
|
45 |
+
try:
|
46 |
+
from jax.experimental.compilation_cache import compilation_cache as cc_old # old-style
|
47 |
+
cc_old.initialize_cache(CACHE_DIR)
|
48 |
+
logging.info("JAX persistent cache (initialize_cache) -> %s", CACHE_DIR)
|
49 |
+
except Exception as e:
|
50 |
+
logging.warning("JAX persistent cache init skipped: %s", e)
|
51 |
|
52 |
+
# Extra XLA caches piggyback on the persistent cache (best effort)
|
53 |
try:
|
54 |
+
import jax
|
55 |
+
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")
|
56 |
+
except Exception as e:
|
57 |
+
logging.info("XLA extra caches not enabled: %s", e)
|
|
|
58 |
|
59 |
|
60 |
|
|
|
79 |
|
80 |
import uuid, threading
|
81 |
|
|
|
|
|
82 |
import gradio as gr
|
83 |
from typing import Optional, Union, Literal
|
84 |
|
|
|
367 |
|
368 |
def _mrt_warmup():
|
369 |
"""
|
370 |
+
Build a minimal, bar-aligned silent context and run a couple of ~2s generate_chunk
|
371 |
+
passes to trigger JIT, fill persistent caches, and run XLA autotune.
|
372 |
"""
|
373 |
global _WARMED
|
374 |
with _WARMUP_LOCK:
|
375 |
if _WARMED:
|
376 |
return
|
377 |
try:
|
378 |
+
# Touch JAX backend early (brings up CUDA context etc.)
|
379 |
+
try:
|
380 |
+
import jax; _ = jax.devices()
|
381 |
+
except Exception:
|
382 |
+
pass
|
383 |
+
|
384 |
+
mrt = get_mrt() # will build model and (with our earlier changes) ensure assets if envs are set
|
385 |
|
386 |
# --- derive timing from model config ---
|
387 |
codec_fps = float(mrt.codec.frame_rate)
|
|
|
422 |
state.context_tokens = context_tokens
|
423 |
style_vec = mrt.embed_style("warmup")
|
424 |
|
425 |
+
# --- prime compiled paths & autotune: run twice ---
|
426 |
+
wav1, state = mrt.generate_chunk(state=state, style=style_vec) # compile + autotune
|
427 |
+
wav2, _ = mrt.generate_chunk(state=state, style=style_vec) # hit cached executables
|
428 |
+
|
429 |
+
# Optional sanity: ensure we didn't return all zeros
|
430 |
+
try:
|
431 |
+
if np.abs(wav2.samples).mean() <= 1e-7:
|
432 |
+
logging.warning("Warmup produced near-silence; continuing.")
|
433 |
+
except Exception:
|
434 |
+
pass
|
435 |
|
436 |
+
logging.info("MagentaRT warmup complete (persistent cache primed).")
|
437 |
finally:
|
438 |
try:
|
439 |
os.unlink(tmp_path)
|