thecollabagepatch committed
Commit c1e9a88 · Parent: 5139a47

let's see if base model runs RT on L4

Files changed (1): app.py (+5, -2)
app.py CHANGED
@@ -15,7 +15,10 @@ os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
 import jax
 # ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
 # TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
-jax.config.update("jax_default_matmul_precision", "tensorfloat32")
+try:
+    jax.config.update("jax_default_matmul_precision", "tensorfloat32")
+except Exception:
+    jax.config.update("jax_default_matmul_precision", "high")  # older alias
 
 # Initialize the on-disk compilation cache (best-effort)
 try:
@@ -447,7 +450,7 @@ def get_mrt():
     if _MRT is None:
         with _MRT_LOCK:
             if _MRT is None:
-                _MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
+                _MRT = system.MagentaRT(tag="base", guidance_weight=5.0, device="gpu", lazy=False)
     return _MRT
 
 _WARMED = False
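
The try/except in the first hunk guards against JAX builds that reject the "tensorfloat32" spelling, falling back to the older "high" alias. A minimal sketch of how one might sanity-check the ~1.1–1.3× claim on an Ampere/Ada card; the bench() helper, matrix size, and printed labels are illustrative assumptions, not part of the commit:

```python
# Hypothetical micro-benchmark: matmul wall time under full fp32 ("highest")
# vs. TF32 ("tensorfloat32"). Shapes and the bench() helper are made up.
import time
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (4096, 4096), dtype=jnp.float32)

def bench(precision: str) -> float:
    with jax.default_matmul_precision(precision):  # scoped override
        f = jax.jit(lambda a: a @ a)
        f(x).block_until_ready()                   # warm-up call compiles the kernel
        t0 = time.perf_counter()
        f(x).block_until_ready()
        return time.perf_counter() - t0

print("highest (full fp32):", bench("highest"))
print("tensorfloat32 (TF32):", bench("tensorfloat32"))
```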
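The body of the `try:` that initializes the on-disk compilation cache is elided by the hunk, and the right call depends on the installed JAX version. A hedged sketch of the usual best-effort pattern, reusing the JAX_CACHE_DIR default set at the top of the file; this is not the commit's actual code:

```python
# Version-dependent sketch: point JAX's persistent compilation cache at a
# writable directory, swallowing failures so startup never breaks on this.
import os
import jax

try:
    # Newer JAX exposes the cache directory as a config option.
    jax.config.update("jax_compilation_cache_dir",
                      os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax"))
except Exception:
    pass  # best-effort: an old JAX or a read-only FS shouldn't block startup
```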
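The second hunk only swaps the checkpoint tag from "large" to "base", which is the point of the commit: see whether the smaller model keeps up in real time on an L4. The surrounding double-checked locking is what makes `get_mrt()` safe to call from concurrent request handlers. A self-contained sketch of that pattern, with a hypothetical `ExpensiveModel` standing in for `system.MagentaRT`:

```python
# Minimal double-checked locking singleton; ExpensiveModel is a stand-in
# for the slow MagentaRT weight load, not a real class from the app.
import threading

_INSTANCE = None
_LOCK = threading.Lock()

class ExpensiveModel:
    def __init__(self, tag: str):
        self.tag = tag  # imagine seconds of weight loading here

def get_instance(tag: str = "base") -> ExpensiveModel:
    global _INSTANCE
    if _INSTANCE is None:          # fast path: no lock once initialized
        with _LOCK:
            if _INSTANCE is None:  # re-check: another thread may have built it
                _INSTANCE = ExpensiveModel(tag)
    return _INSTANCE
```

The inner `is None` re-check is the load-bearing part: two threads can both pass the outer check before either acquires the lock, and without the re-check each would construct its own model. Since the real call uses `lazy=False`, construction loads weights eagerly, so a duplicate build could transiently double GPU memory use.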