thecollabagepatch committed on
Commit
8381f2e
·
1 Parent(s): 8999b96

reverting jax cache changes

Browse files
Files changed (1) hide show
  1. app.py +20 -44
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import logging, os
2
 
3
  # ---- Space mode gating (place above any JAX import!) ----
4
  SPACE_MODE = os.getenv("SPACE_MODE")
@@ -28,33 +28,21 @@ else:
28
  # Optional: persist JAX compile cache across restarts (reduces warmup time)
29
  os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
30
 
31
- # --- JAX persistent compilation cache (new + old APIs), plus extra XLA caches ---
32
-
33
-
34
- CACHE_DIR = os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
35
-
36
- # Prefer new API (JAX ≥ 0.4.26 / 0.5+), fall back to older initialize_cache
37
  try:
38
- from jax.experimental import compilation_cache as cc # new-style
39
- if hasattr(cc, "set_cache_dir"):
40
- cc.set_cache_dir(CACHE_DIR)
41
- logging.info("JAX persistent cache (set_cache_dir) -> %s", CACHE_DIR)
42
- else:
43
- raise ImportError
44
  except Exception:
45
- try:
46
- from jax.experimental.compilation_cache import compilation_cache as cc_old # old-style
47
- cc_old.initialize_cache(CACHE_DIR)
48
- logging.info("JAX persistent cache (initialize_cache) -> %s", CACHE_DIR)
49
- except Exception as e:
50
- logging.warning("JAX persistent cache init skipped: %s", e)
51
 
52
- # Extra XLA caches piggyback on the persistent cache (best effort)
53
  try:
54
- import jax
55
- jax.config.update("jax_persistent_cache_enable_xla_caches", "all")
56
- except Exception as e:
57
- logging.info("XLA extra caches not enabled: %s", e)
 
58
 
59
 
60
 
@@ -79,6 +67,8 @@ from one_shot_generation import generate_loop_continuation_with_mrt, generate_st
79
 
80
  import uuid, threading
81
 
 
 
82
  import gradio as gr
83
  from typing import Optional, Union, Literal
84
 
@@ -367,21 +357,15 @@ _WARMUP_LOCK = threading.Lock()
367
 
368
  def _mrt_warmup():
369
  """
370
- Build a minimal, bar-aligned silent context and run a couple of ~2s generate_chunk
371
- passes to trigger JIT, fill persistent caches, and run XLA autotune.
372
  """
373
  global _WARMED
374
  with _WARMUP_LOCK:
375
  if _WARMED:
376
  return
377
  try:
378
- # Touch JAX backend early (brings up CUDA context etc.)
379
- try:
380
- import jax; _ = jax.devices()
381
- except Exception:
382
- pass
383
-
384
- mrt = get_mrt() # will build model and (with our earlier changes) ensure assets if envs are set
385
 
386
  # --- derive timing from model config ---
387
  codec_fps = float(mrt.codec.frame_rate)
@@ -422,18 +406,10 @@ def _mrt_warmup():
422
  state.context_tokens = context_tokens
423
  style_vec = mrt.embed_style("warmup")
424
 
425
- # --- prime compiled paths & autotune: run twice ---
426
- wav1, state = mrt.generate_chunk(state=state, style=style_vec) # compile + autotune
427
- wav2, _ = mrt.generate_chunk(state=state, style=style_vec) # hit cached executables
428
-
429
- # Optional sanity: ensure we didn't return all zeros
430
- try:
431
- if np.abs(wav2.samples).mean() <= 1e-7:
432
- logging.warning("Warmup produced near-silence; continuing.")
433
- except Exception:
434
- pass
435
 
436
- logging.info("MagentaRT warmup complete (persistent cache primed).")
437
  finally:
438
  try:
439
  os.unlink(tmp_path)
 
1
+ import os
2
 
3
  # ---- Space mode gating (place above any JAX import!) ----
4
  SPACE_MODE = os.getenv("SPACE_MODE")
 
28
  # Optional: persist JAX compile cache across restarts (reduces warmup time)
29
  os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
30
 
31
+ import jax
32
+ # ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
33
+ # TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
 
 
 
34
  try:
35
+ jax.config.update("jax_default_matmul_precision", "tensorfloat32")
 
 
 
 
 
36
  except Exception:
37
+ jax.config.update("jax_default_matmul_precision", "high") # older alias
 
 
 
 
 
38
 
39
+ # Initialize the on-disk compilation cache (best-effort)
40
  try:
41
+ from jax.experimental.compilation_cache import compilation_cache as cc
42
+ cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
43
+ except Exception:
44
+ pass
45
+ # --------------------------------------------------------------------
46
 
47
 
48
 
 
67
 
68
  import uuid, threading
69
 
70
+ import logging
71
+
72
  import gradio as gr
73
  from typing import Optional, Union, Literal
74
 
 
357
 
358
  def _mrt_warmup():
359
  """
360
+ Build a minimal, bar-aligned silent context and run one 2s generate_chunk
361
+ to trigger XLA JIT & autotune so first real request is fast.
362
  """
363
  global _WARMED
364
  with _WARMUP_LOCK:
365
  if _WARMED:
366
  return
367
  try:
368
+ mrt = get_mrt()
 
 
 
 
 
 
369
 
370
  # --- derive timing from model config ---
371
  codec_fps = float(mrt.codec.frame_rate)
 
406
  state.context_tokens = context_tokens
407
  style_vec = mrt.embed_style("warmup")
408
 
409
+ # --- one throwaway chunk (~2s) ---
410
+ _wav, _state = mrt.generate_chunk(state=state, style=style_vec)
 
 
 
 
 
 
 
 
411
 
412
+ logging.info("MagentaRT warmup complete.")
413
  finally:
414
  try:
415
  os.unlink(tmp_path)