thecollabagepatch commited on
Commit
5139a47
·
1 Parent(s): 227a9e0

another attempt at RT speedup for L4

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -1,3 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from magenta_rt import system, audio as au
2
  import numpy as np
3
  from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
@@ -15,7 +44,7 @@ from utils import (
15
 
16
  from jam_worker import JamWorker, JamParams, JamChunk
17
  import uuid, threading
18
- import os
19
  import logging
20
 
21
  import gradio as gr
@@ -25,25 +54,6 @@ from typing import Optional
25
  import json, asyncio, base64
26
  import time
27
 
28
- # ---- Perf knobs (add at top of app.py) ----
29
- os.environ.setdefault("JAX_PLATFORMS", "cuda") # prefer GPU
30
- os.environ.setdefault("XLA_FLAGS",
31
- "--xla_gpu_enable_triton_gemm=true "
32
- "--xla_gpu_enable_latency_hiding_scheduler=true "
33
- "--xla_gpu_autotune_level=2")
34
- # TF32 is enabled by default on Ampere/Ada for matmul; ensure not disabled:
35
- os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
36
-
37
- import jax
38
- jax.config.update("jax_default_matmul_precision", "fastest") # allow TF32
39
- # Optional: persist XLA compile artifacts across restarts (saves warmup time)
40
- try:
41
- from jax.experimental.compilation_cache import compilation_cache as cc
42
- cc.initialize_cache(os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax"))
43
- except Exception:
44
- pass
45
- # --------------------------------------------
46
-
47
 
48
 
49
  from starlette.websockets import WebSocketState
 
1
+ import os
2
+ # Useful XLA GPU optimizations (harmless if a flag is unknown)
3
+ os.environ.setdefault(
4
+ "XLA_FLAGS",
5
+ " ".join([
6
+ "--xla_gpu_enable_triton_gemm=true",
7
+ "--xla_gpu_enable_latency_hiding_scheduler=true",
8
+ "--xla_gpu_autotune_level=2",
9
+ ])
10
+ )
11
+
12
+ # Optional: persist JAX compile cache across restarts (reduces warmup time)
13
+ os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
14
+
15
+ import jax
16
+ # ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
17
+ # TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
18
+ jax.config.update("jax_default_matmul_precision", "tensorfloat32")
19
+
20
+ # Initialize the on-disk compilation cache (best-effort)
21
+ try:
22
+ from jax.experimental.compilation_cache import compilation_cache as cc
23
+ cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
24
+ except Exception:
25
+ pass
26
+ # --------------------------------------------------------------------
27
+
28
+
29
+
30
  from magenta_rt import system, audio as au
31
  import numpy as np
32
  from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
 
44
 
45
  from jam_worker import JamWorker, JamParams, JamChunk
46
  import uuid, threading
47
+
48
  import logging
49
 
50
  import gradio as gr
 
54
  import json, asyncio, base64
55
  import time
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  from starlette.websockets import WebSocketState