thecollabagepatch committed on
Commit
4843704
·
1 Parent(s): e140f31

jax cache for faster compilation attempt

Browse files
Files changed (1) hide show
  1. app.py +44 -20
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
 
3
  # ---- Space mode gating (place above any JAX import!) ----
4
  SPACE_MODE = os.getenv("SPACE_MODE")
@@ -28,21 +28,33 @@ else:
28
  # Optional: persist JAX compile cache across restarts (reduces warmup time)
29
  os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
30
 
31
- import jax
32
- # ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
33
- # TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
 
 
 
34
  try:
35
- jax.config.update("jax_default_matmul_precision", "tensorfloat32")
 
 
 
 
 
36
  except Exception:
37
- jax.config.update("jax_default_matmul_precision", "high") # older alias
 
 
 
 
 
38
 
39
- # Initialize the on-disk compilation cache (best-effort)
40
  try:
41
- from jax.experimental.compilation_cache import compilation_cache as cc
42
- cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
43
- except Exception:
44
- pass
45
- # --------------------------------------------------------------------
46
 
47
 
48
 
@@ -67,8 +79,6 @@ from one_shot_generation import generate_loop_continuation_with_mrt, generate_st
67
 
68
  import uuid, threading
69
 
70
- import logging
71
-
72
  import gradio as gr
73
  from typing import Optional, Union, Literal
74
 
@@ -357,15 +367,21 @@ _WARMUP_LOCK = threading.Lock()
357
 
358
  def _mrt_warmup():
359
  """
360
- Build a minimal, bar-aligned silent context and run one 2s generate_chunk
361
- to trigger XLA JIT & autotune so first real request is fast.
362
  """
363
  global _WARMED
364
  with _WARMUP_LOCK:
365
  if _WARMED:
366
  return
367
  try:
368
- mrt = get_mrt()
 
 
 
 
 
 
369
 
370
  # --- derive timing from model config ---
371
  codec_fps = float(mrt.codec.frame_rate)
@@ -406,10 +422,18 @@ def _mrt_warmup():
406
  state.context_tokens = context_tokens
407
  style_vec = mrt.embed_style("warmup")
408
 
409
- # --- one throwaway chunk (~2s) ---
410
- _wav, _state = mrt.generate_chunk(state=state, style=style_vec)
 
 
 
 
 
 
 
 
411
 
412
- logging.info("MagentaRT warmup complete.")
413
  finally:
414
  try:
415
  os.unlink(tmp_path)
 
1
+ import logging, os
2
 
3
  # ---- Space mode gating (place above any JAX import!) ----
4
  SPACE_MODE = os.getenv("SPACE_MODE")
 
28
  # Optional: persist JAX compile cache across restarts (reduces warmup time)
29
  os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
30
 
31
+ # --- JAX persistent compilation cache (new + old APIs), plus extra XLA caches ---
32
+
33
+
34
+ CACHE_DIR = os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
35
+
36
+ # Prefer new API (JAX ≥ 0.4.26 / 0.5+), fall back to older initialize_cache
37
  try:
38
+ from jax.experimental import compilation_cache as cc # new-style
39
+ if hasattr(cc, "set_cache_dir"):
40
+ cc.set_cache_dir(CACHE_DIR)
41
+ logging.info("JAX persistent cache (set_cache_dir) -> %s", CACHE_DIR)
42
+ else:
43
+ raise ImportError
44
  except Exception:
45
+ try:
46
+ from jax.experimental.compilation_cache import compilation_cache as cc_old # old-style
47
+ cc_old.initialize_cache(CACHE_DIR)
48
+ logging.info("JAX persistent cache (initialize_cache) -> %s", CACHE_DIR)
49
+ except Exception as e:
50
+ logging.warning("JAX persistent cache init skipped: %s", e)
51
 
52
+ # Extra XLA caches piggyback on the persistent cache (best effort)
53
  try:
54
+ import jax
55
+ jax.config.update("jax_persistent_cache_enable_xla_caches", "all")
56
+ except Exception as e:
57
+ logging.info("XLA extra caches not enabled: %s", e)
 
58
 
59
 
60
 
 
79
 
80
  import uuid, threading
81
 
 
 
82
  import gradio as gr
83
  from typing import Optional, Union, Literal
84
 
 
367
 
368
  def _mrt_warmup():
369
  """
370
+ Build a minimal, bar-aligned silent context and run a couple of ~2s generate_chunk
371
+ passes to trigger JIT, fill persistent caches, and run XLA autotune.
372
  """
373
  global _WARMED
374
  with _WARMUP_LOCK:
375
  if _WARMED:
376
  return
377
  try:
378
+ # Touch JAX backend early (brings up CUDA context etc.)
379
+ try:
380
+ import jax; _ = jax.devices()
381
+ except Exception:
382
+ pass
383
+
384
+ mrt = get_mrt() # will build model and (with our earlier changes) ensure assets if envs are set
385
 
386
  # --- derive timing from model config ---
387
  codec_fps = float(mrt.codec.frame_rate)
 
422
  state.context_tokens = context_tokens
423
  style_vec = mrt.embed_style("warmup")
424
 
425
+ # --- prime compiled paths & autotune: run twice ---
426
+ wav1, state = mrt.generate_chunk(state=state, style=style_vec) # compile + autotune
427
+ wav2, _ = mrt.generate_chunk(state=state, style=style_vec) # hit cached executables
428
+
429
+ # Optional sanity: ensure we didn't return all zeros
430
+ try:
431
+ if np.abs(wav2.samples).mean() <= 1e-7:
432
+ logging.warning("Warmup produced near-silence; continuing.")
433
+ except Exception:
434
+ pass
435
 
436
+ logging.info("MagentaRT warmup complete (persistent cache primed).")
437
  finally:
438
  try:
439
  os.unlink(tmp_path)