thecollabagepatch commited on
Commit
227a9e0
·
1 Parent(s): d54b5ce

attempting RT speedup for L4

Browse files
Files changed (1) hide show
  1. app.py +22 -0
app.py CHANGED
@@ -24,6 +24,28 @@ from typing import Optional
24
 
25
  import json, asyncio, base64
26
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  from starlette.websockets import WebSocketState
28
  try:
29
  from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
 
24
 
25
  import json, asyncio, base64
26
  import time
27
+
28
+ # ---- Perf knobs (add at top of app.py) ----
29
+ os.environ.setdefault("JAX_PLATFORMS", "cuda") # prefer GPU
30
+ os.environ.setdefault("XLA_FLAGS",
31
+ "--xla_gpu_enable_triton_gemm=true "
32
+ "--xla_gpu_enable_latency_hiding_scheduler=true "
33
+ "--xla_gpu_autotune_level=2")
34
+ # TF32 is enabled by default on Ampere/Ada for matmul; ensure not disabled:
35
+ os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
36
+
37
+ import jax
38
+ jax.config.update("jax_default_matmul_precision", "fastest") # allow TF32
39
+ # Optional: persist XLA compile artifacts across restarts (saves warmup time)
40
+ try:
41
+ from jax.experimental.compilation_cache import compilation_cache as cc
42
+ cc.initialize_cache(os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax"))
43
+ except Exception:
44
+ pass
45
+ # --------------------------------------------
46
+
47
+
48
+
49
  from starlette.websockets import WebSocketState
50
  try:
51
  from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20