Commit
·
c1e9a88
1
Parent(s):
5139a47
let's see if base model runs RT on L4
Browse files
app.py
CHANGED
@@ -15,7 +15,10 @@ os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
|
15 |
import jax
|
16 |
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
17 |
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
18 |
-
|
|
|
|
|
|
|
19 |
|
20 |
# Initialize the on-disk compilation cache (best-effort)
|
21 |
try:
|
@@ -447,7 +450,7 @@ def get_mrt():
|
|
447 |
if _MRT is None:
|
448 |
with _MRT_LOCK:
|
449 |
if _MRT is None:
|
450 |
-
_MRT = system.MagentaRT(tag="
|
451 |
return _MRT
|
452 |
|
453 |
_WARMED = False
|
|
|
15 |
import jax
|
16 |
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
17 |
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
18 |
+
try:
|
19 |
+
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
20 |
+
except Exception:
|
21 |
+
jax.config.update("jax_default_matmul_precision", "high") # older alias
|
22 |
|
23 |
# Initialize the on-disk compilation cache (best-effort)
|
24 |
try:
|
|
|
450 |
if _MRT is None:
|
451 |
with _MRT_LOCK:
|
452 |
if _MRT is None:
|
453 |
+
_MRT = system.MagentaRT(tag="base", guidance_weight=5.0, device="gpu", lazy=False)
|
454 |
return _MRT
|
455 |
|
456 |
_WARMED = False
|