Commit
·
227a9e0
1
Parent(s):
d54b5ce
attempting RT speedup for L4
Browse files
app.py
CHANGED
@@ -24,6 +24,28 @@ from typing import Optional
|
|
24 |
|
25 |
import json, asyncio, base64
|
26 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
from starlette.websockets import WebSocketState
|
28 |
try:
|
29 |
from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
|
|
|
24 |
|
25 |
import json, asyncio, base64
|
26 |
import time
|
27 |
+
|
28 |
+
# ---- Perf knobs ----
# These run at import time, before anything touches a JAX backend: the
# environment variables below are only read when JAX/XLA initializes.
# setdefault() is used throughout so values exported by the deployment
# environment always take precedence over these defaults.
os.environ.setdefault("JAX_PLATFORMS", "cuda")  # prefer GPU
os.environ.setdefault(
    "XLA_FLAGS",
    "--xla_gpu_enable_triton_gemm=true "
    "--xla_gpu_enable_latency_hiding_scheduler=true "
    "--xla_gpu_autotune_level=2",
)
# NOTE: NVIDIA_TF32_OVERRIDE is intentionally NOT set here. Per NVIDIA's
# documentation the value "0" force-disables TF32 in NVIDIA libraries, which
# is the opposite of what a speed-up wants. Leaving the variable unset keeps
# the library default; an operator who exports it explicitly still wins.

import warnings

import jax

# Lowest-precision / fastest matmul level. "bfloat16" is the config-option
# spelling of that level ("fastest" is a per-op precision alias and, as far
# as the documented option values go, is not accepted by config.update).
jax.config.update("jax_default_matmul_precision", "bfloat16")

# Optional: persist XLA compile artifacts across restarts (saves warmup time).
# Best effort only: a failure here must never prevent the app from starting.
_jax_cache_dir = os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
try:
    # Preferred on current JAX releases.
    jax.config.update("jax_compilation_cache_dir", _jax_cache_dir)
except AttributeError:
    # Older JAX without that config option: fall back to the legacy helper.
    try:
        from jax.experimental.compilation_cache import compilation_cache as cc

        cc.initialize_cache(_jax_cache_dir)
    except Exception as exc:
        warnings.warn(f"JAX persistent compilation cache not enabled: {exc!r}")
except Exception as exc:
    warnings.warn(f"JAX persistent compilation cache not enabled: {exc!r}")
# --------------------------------------------
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
from starlette.websockets import WebSocketState
|
50 |
try:
|
51 |
from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
|