Commit
·
5139a47
1
Parent(s):
227a9e0
another attempt at RT speedup for L4
Browse files
app.py
CHANGED
@@ -1,3 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from magenta_rt import system, audio as au
|
2 |
import numpy as np
|
3 |
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
|
@@ -15,7 +44,7 @@ from utils import (
|
|
15 |
|
16 |
from jam_worker import JamWorker, JamParams, JamChunk
|
17 |
import uuid, threading
|
18 |
-
|
19 |
import logging
|
20 |
|
21 |
import gradio as gr
|
@@ -25,25 +54,6 @@ from typing import Optional
|
|
25 |
import json, asyncio, base64
|
26 |
import time
|
27 |
|
28 |
-
# ---- Perf knobs (add at top of app.py) ----
|
29 |
-
os.environ.setdefault("JAX_PLATFORMS", "cuda") # prefer GPU
|
30 |
-
os.environ.setdefault("XLA_FLAGS",
|
31 |
-
"--xla_gpu_enable_triton_gemm=true "
|
32 |
-
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
33 |
-
"--xla_gpu_autotune_level=2")
|
34 |
-
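# NOTE: TF32 is enabled by default on Ampere/Ada for matmul, but NVIDIA_TF32_OVERRIDE=0 force-disables it in NVIDIA libraries — the setting below turns TF32 off rather than keeping it on: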
|
35 |
-
os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")
|
36 |
-
|
37 |
-
import jax
|
38 |
-
jax.config.update("jax_default_matmul_precision", "fastest") # allow TF32
|
39 |
-
# Optional: persist XLA compile artifacts across restarts (saves warmup time)
|
40 |
-
try:
|
41 |
-
from jax.experimental.compilation_cache import compilation_cache as cc
|
42 |
-
cc.initialize_cache(os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax"))
|
43 |
-
except Exception:
|
44 |
-
pass
|
45 |
-
# --------------------------------------------
|
46 |
-
|
47 |
|
48 |
|
49 |
from starlette.websockets import WebSocketState
|
|
|
1 |
+
import os
|
2 |
+
# Useful XLA GPU optimizations (caution: XLA aborts at startup if XLA_FLAGS contains a flag it does not recognize)
|
3 |
+
os.environ.setdefault(
|
4 |
+
"XLA_FLAGS",
|
5 |
+
" ".join([
|
6 |
+
"--xla_gpu_enable_triton_gemm=true",
|
7 |
+
"--xla_gpu_enable_latency_hiding_scheduler=true",
|
8 |
+
"--xla_gpu_autotune_level=2",
|
9 |
+
])
|
10 |
+
)
|
11 |
+
|
12 |
+
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
13 |
+
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
14 |
+
|
15 |
+
import jax
|
16 |
+
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
17 |
+
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
18 |
+
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
19 |
+
|
20 |
+
# Initialize the on-disk compilation cache (best-effort)
|
21 |
+
try:
|
22 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
23 |
+
cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
|
24 |
+
except Exception:
|
25 |
+
pass
|
26 |
+
# --------------------------------------------------------------------
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
from magenta_rt import system, audio as au
|
31 |
import numpy as np
|
32 |
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
|
|
|
44 |
|
45 |
from jam_worker import JamWorker, JamParams, JamChunk
|
46 |
import uuid, threading
|
47 |
+
|
48 |
import logging
|
49 |
|
50 |
import gradio as gr
|
|
|
54 |
import json, asyncio, base64
|
55 |
import time
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
|
59 |
from starlette.websockets import WebSocketState
|