File size: 65,160 Bytes
8381f2e 85e8363 e5eaa7e 85e8363 5139a47 8381f2e c1e9a88 8381f2e c1e9a88 8381f2e 5139a47 8381f2e 5139a47 8381f2e 5139a47 d1afbc8 e5eaa7e 46f9035 d1afbc8 842a99f d1afbc8 5139a47 8381f2e 1a7ea3c 49b5fee 1a7ea3c 2bd198e 6896250 227a9e0 d54b5ce 80346b6 147fd8b 49b5fee 147fd8b 384e4ac 85e8363 a8de318 30fdbbc a8de318 30fdbbc 49b5fee 384e4ac 49b5fee f0823f1 49b5fee 147fd8b d54b5ce 2bd198e 9a1b4dc d1afbc8 30fdbbc a8de318 406bd0f d1afbc8 f0823f1 147fd8b 384e4ac 147fd8b 384e4ac f0823f1 147fd8b f0823f1 d1afbc8 87147f5 8381f2e 87147f5 8381f2e 87147f5 b1bc032 87147f5 8381f2e 87147f5 8381f2e 87147f5 30fdbbc 87147f5 f0823f1 87147f5 147fd8b a8de318 384e4ac a8de318 49b5fee f94c8f9 384e4ac f94c8f9 384e4ac f94c8f9 384e4ac f94c8f9 384e4ac f94c8f9 49b5fee f94c8f9 49b5fee 384e4ac 49b5fee 0577e3b 384e4ac f0823f1 0577e3b 384e4ac 0577e3b 49b5fee 384e4ac f0823f1 c211cf7 49b5fee 0577e3b 49b5fee 0577e3b 384e4ac 0577e3b 49b5fee 384e4ac 49b5fee 0577e3b 49b5fee 0577e3b 384e4ac 0577e3b 49b5fee 0577e3b 49b5fee 0577e3b 384e4ac 0577e3b 384e4ac 0577e3b 384e4ac 0577e3b 49b5fee 0577e3b 49b5fee 0577e3b 49b5fee 0577e3b 49b5fee 0577e3b 30fdbbc 49b5fee aba0837 d1afbc8 406bd0f d1afbc8 a8de318 d1afbc8 384e4ac a8de318 d1afbc8 a8de318 d1afbc8 a8de318 af5a7b2 a8de318 af5a7b2 d1afbc8 a8de318 af5a7b2 a8de318 af5a7b2 a8de318 af5a7b2 384e4ac a8de318 8cedcd0 a8de318 af5a7b2 8cedcd0 a8de318 8cedcd0 a8de318 8cedcd0 d1afbc8 85e8363 1a7ea3c e7bec6e 4a4198e 2bd198e 6896250 2bd198e d54b5ce 2bd198e d54b5ce 2bd198e a8de318 2bd198e a8de318 2bd198e a8de318 6896250 2bd198e 6896250 a8de318 384e4ac a8de318 2bd198e 60a96af a8de318 2bd198e 6896250 2bd198e 6896250 60a96af 2bd198e 6896250 2bd198e 6896250 d54b5ce 2bd198e d54b5ce a8de318 d54b5ce 2bd198e a8de318 6896250 d54b5ce 6896250 d54b5ce 6896250 d54b5ce a8de318 d54b5ce 2bd198e 6896250 a8de318 384e4ac 60a96af a8de318 60a96af a8de318 2bd198e d54b5ce 2bd198e d54b5ce 2bd198e 4a4198e dcfd5bb 00fc4dc 46f9035 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 |
import os
# ---- Space mode gating (place above any JAX import!) ----
SPACE_MODE = os.getenv("SPACE_MODE")
if SPACE_MODE is None:
try:
import jax
SPACE_MODE = "serve" if any(getattr(d, "platform", "") in ("gpu","cuda","rocm") for d in jax.devices()) else "template"
except Exception:
SPACE_MODE = "template"
if SPACE_MODE != "serve":
# In template mode, force JAX to CPU so it won't try to load CUDA plugins
os.environ.setdefault("JAX_PLATFORMS", "cpu")
else:
# Only set GPU-friendly XLA flags when we actually intend to serve on GPU
os.environ.setdefault(
"XLA_FLAGS",
" ".join([
"--xla_gpu_enable_triton_gemm=true",
"--xla_gpu_enable_latency_hiding_scheduler=true",
"--xla_gpu_autotune_level=2",
])
)
# Optional: persist JAX compile cache across restarts (reduces warmup time)
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
import jax
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
try:
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
except Exception:
jax.config.update("jax_default_matmul_precision", "high") # older alias
# Initialize the on-disk compilation cache (best-effort)
try:
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache(os.environ["JAX_CACHE_DIR"])
except Exception:
pass
# --------------------------------------------------------------------
from magenta_rt import system, audio as au
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
import tempfile, io, base64, math, threading
from fastapi.middleware.cors import CORSMiddleware
from contextlib import contextmanager
import soundfile as sf
from math import gcd
from scipy.signal import resample_poly
from utils import (
match_loudness_to_reference, stitch_generated, hard_trim_seconds,
apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
resample_and_snap, wav_bytes_base64
)
from jam_worker import JamWorker, JamParams, JamChunk
from one_shot_generation import generate_loop_continuation_with_mrt, generate_style_only_with_mrt
import uuid, threading
import logging
import gradio as gr
from typing import Optional, Union, Literal
import json, asyncio, base64
import time
from starlette.websockets import WebSocketState
try:
from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
except Exception:
class ClientDisconnected(Exception): # fallback
pass
import re, tarfile
from pathlib import Path
from huggingface_hub import snapshot_download, HfApi
from pydantic import BaseModel
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
def _gpu_probe() -> dict:
"""
Returns:
{
"ok": bool,
"backend": str | None, # "gpu" | "cpu" | "tpu" | None
"has_gpu": bool,
"devices": list[str], # e.g. ["gpu:0", "gpu:1"]
"error": str | None,
}
"""
try:
import jax
try:
backend = jax.default_backend() # "gpu", "cpu", "tpu"
except Exception:
from jax.lib import xla_bridge
backend = getattr(xla_bridge.get_backend(), "platform", None)
try:
devices = jax.devices()
has_gpu = any(getattr(d, "platform", "") in ("gpu", "cuda", "rocm") for d in devices)
dev_list = [f"{getattr(d, 'platform', '?')}:{getattr(d, 'id', '?')}" for d in devices]
return {"ok": True, "backend": backend, "has_gpu": has_gpu, "devices": dev_list, "error": None}
except Exception as e:
return {"ok": False, "backend": backend, "has_gpu": False, "devices": [], "error": f"jax.devices failed: {e}"}
except Exception as e:
return {"ok": False, "backend": None, "has_gpu": False, "devices": [], "error": f"jax import failed: {e}"}
# ---- Finetune assets (mean & centroids) --------------------------------------
# _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
_ASSETS_REPO_ID: str | None = None
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
# _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
# Create instances (these don't modify globals)
asset_manager = AssetManager()
model_selector = ModelSelector(CheckpointManager(), asset_manager)
def _sync_assets_globals_from_manager():
# Keeps /model/config in sync with what the asset manager has
global _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
_MEAN_EMBED = asset_manager.mean_embed
_CENTROIDS = asset_manager.centroids
_ASSETS_REPO_ID = asset_manager.assets_repo_id
def _any_jam_running() -> bool:
with jam_lock:
return any(w.is_alive() for w in jam_registry.values())
def _stop_all_jams(timeout: float = 5.0):
with jam_lock:
for sid, w in list(jam_registry.items()):
if w.is_alive():
w.stop()
w.join(timeout=timeout)
jam_registry.pop(sid, None)
async def send_json_safe(ws: WebSocket, obj) -> bool:
"""Try to send. Returns False if the socket is (or becomes) closed."""
if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
return False
try:
await ws.send_text(json.dumps(obj))
return True
except (WebSocketDisconnect, ClientDisconnected, RuntimeError):
return False
except Exception:
return False
# --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
def _patch_t5x_for_gpu_coords():
try:
import jax
from t5x import partitioning as _t5x_part
old_bounds = getattr(_t5x_part, "bounds_from_last_device", None)
old_getcoords = getattr(_t5x_part, "get_coords", None)
def _bounds_from_last_device_gpu_safe(last_device):
# TPU: coords + core_on_chip
core = getattr(last_device, "core_on_chip", None)
coords = getattr(last_device, "coords", None)
if coords is not None and core is not None:
x, y, z = coords
return x + 1, y + 1, z + 1, core + 1
# Non-TPU (or GPU lacking core_on_chip): hosts x local_devices
return jax.host_count(), jax.local_device_count()
def _get_coords_gpu_safe(device):
core = getattr(device, "core_on_chip", None)
coords = getattr(device, "coords", None)
if coords is not None and core is not None:
return (*coords, core)
# Fallback that works on CPU/GPU
return (device.process_index, device.id % jax.local_device_count())
_t5x_part.bounds_from_last_device = _bounds_from_last_device_gpu_safe
_t5x_part.get_coords = _get_coords_gpu_safe
import logging; logging.info("Patched t5x.partitioning for GPU coords without core_on_chip.")
except Exception as e:
import logging; logging.exception("t5x GPU-coords patch failed: %s", e)
# Call the patch immediately at import time (before MagentaRT init)
_patch_t5x_for_gpu_coords()
jam_registry: dict[str, JamWorker] = {}
jam_lock = threading.Lock()
@contextmanager
def mrt_overrides(mrt, **kwargs):
"""Temporarily set attributes on MRT if they exist; restore after."""
old = {}
try:
for k, v in kwargs.items():
if hasattr(mrt, k):
old[k] = getattr(mrt, k)
setattr(mrt, k, v)
yield
finally:
for k, v in old.items():
setattr(mrt, k, v)
# loudness utils
try:
import pyloudnorm as pyln
_HAS_LOUDNORM = True
except Exception:
_HAS_LOUDNORM = False
# def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
# extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
# if not extra:
# return mrt.embed_style("warmup")
# sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
# embeds, weights = [], []
# for i, s in enumerate(extra):
# embeds.append(mrt.embed_style(s))
# weights.append(sw[i] if i < len(sw) else 1.0)
# wsum = sum(weights) or 1.0
# weights = [w/wsum for w in weights]
# import numpy as np
# return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
def build_style_vector(
mrt,
*,
text_styles: list[str] | None = None,
text_weights: list[float] | None = None,
loop_embed: np.ndarray | None = None,
loop_weight: float | None = None,
mean_weight: float | None = None,
centroid_weights: list[float] | None = None,
) -> np.ndarray:
"""
Returns a single style embedding combining:
- loop embedding (optional)
- one or more text style embeddings (optional)
- mean finetune embedding (optional)
- centroid embeddings (optional)
All weights are normalized so they sum to 1 if > 0.
"""
comps: list[np.ndarray] = []
weights: list[float] = []
# loop component
if loop_embed is not None and (loop_weight or 0) > 0:
comps.append(loop_embed.astype(np.float32, copy=False))
weights.append(float(loop_weight))
# text components
if text_styles:
for i, s in enumerate(text_styles):
s = s.strip()
if not s:
continue
w = 1.0
if text_weights and i < len(text_weights):
try: w = float(text_weights[i])
except: w = 1.0
if w <= 0:
continue
e = mrt.embed_style(s)
comps.append(e.astype(np.float32, copy=False))
weights.append(w)
# mean finetune
if mean_weight and (_MEAN_EMBED is not None) and mean_weight > 0:
comps.append(_MEAN_EMBED)
weights.append(float(mean_weight))
# centroid components
if centroid_weights and _CENTROIDS is not None:
K = _CENTROIDS.shape[0]
for k, w in enumerate(centroid_weights[:K]):
try: w = float(w)
except: w = 0.0
if w <= 0:
continue
comps.append(_CENTROIDS[k])
weights.append(w)
if not comps:
# fallback: neutral style if nothing provided
return mrt.embed_style("")
wsum = sum(weights)
if wsum <= 0:
return mrt.embed_style("")
weights = [w/wsum for w in weights]
# weighted sum
out = np.zeros_like(comps[0], dtype=np.float32)
for w, e in zip(weights, comps):
out += w * e.astype(np.float32, copy=False)
return out
# ----------------------------
# FastAPI app with lazy, thread-safe model init
# ----------------------------
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # or lock to your domain(s)
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
_MRT = None
_MRT_LOCK = threading.Lock()
def get_mrt():
global _MRT
if _MRT is None:
with _MRT_LOCK:
if _MRT is None:
ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # uses MRT_CKPT_REPO/STEP if present
_MRT = system.MagentaRT(
tag=os.getenv("MRT_SIZE", "large"),
guidance_weight=5.0,
device="gpu",
checkpoint_dir=ckpt_dir,
lazy=False
)
# If no assets loaded yet, and a repo is configured, load them now.
if asset_manager.mean_embed is None and asset_manager.centroids is None:
repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
if repo:
asset_manager.load_finetune_assets_from_hf(repo, None)
_sync_assets_globals_from_manager()
return _MRT
_WARMED = False
_WARMUP_LOCK = threading.Lock()
def _mrt_warmup():
"""
Build a minimal, bar-aligned silent context and run one 2s generate_chunk
to trigger XLA JIT & autotune so first real request is fast.
"""
global _WARMED
with _WARMUP_LOCK:
if _WARMED:
return
try:
mrt = get_mrt()
# --- derive timing from model config ---
codec_fps = float(mrt.codec.frame_rate)
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
sr = int(mrt.sample_rate)
# We'll align to 120 BPM, 4/4, and generate one ~2s chunk
bpm = 120.0
beats_per_bar = 4
# --- build a silent, stereo context of ctx_seconds ---
samples = int(max(1, round(ctx_seconds * sr)))
silent = np.zeros((samples, 2), dtype=np.float32)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
sf.write(tmp.name, silent, sr, subtype="PCM_16")
tmp_path = tmp.name
try:
# Load as Waveform and take a tail of exactly ctx_seconds
loop = au.Waveform.from_file(tmp_path).resample(sr).as_stereo()
seconds_per_bar = beats_per_bar * (60.0 / bpm)
ctx_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
# Tokens for context window
tokens_full = mrt.codec.encode(ctx_tail).astype(np.int32)
tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
context_tokens = make_bar_aligned_context(
tokens,
bpm=bpm,
fps=float(mrt.codec.frame_rate),
ctx_frames=mrt.config.context_length_frames,
beats_per_bar=beats_per_bar,
)
# Init state and a basic style vector (text token is fine)
state = mrt.init_state()
state.context_tokens = context_tokens
style_vec = mrt.embed_style("warmup")
# --- one throwaway chunk (~2s) ---
_wav, _state = mrt.generate_chunk(state=state, style=style_vec)
logging.info("MagentaRT warmup complete.")
finally:
try:
os.unlink(tmp_path)
except Exception:
pass
_WARMED = True
except Exception as e:
# Never crash on warmup errors; log and continue serving
logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
# ----------------------------
# startup and model selection
# ----------------------------
@app.on_event("startup")
def _boot():
# 1) Load finetune assets up front (only if envs are present)
repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
if repo:
ok, msg = asset_manager.load_finetune_assets_from_hf(repo, None)
_sync_assets_globals_from_manager() # keep /model/config in sync
logging.info("Startup asset load from %s: %s", repo, "ok" if ok else msg)
else:
logging.info("Startup asset load: no repo env set; skipping.")
# 2) Start warmup in the background (unchanged behavior)
if os.getenv("MRT_WARMUP", "1") != "0":
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
@app.get("/model/status")
def model_status():
mrt = get_mrt()
return {
"tag": getattr(mrt, "_tag", "unknown"),
"using_checkpoint_dir": True,
"codec_frame_rate": float(mrt.codec.frame_rate),
"decoder_rvq_depth": int(mrt.config.decoder_codec_rvq_depth),
"context_seconds": float(mrt.config.context_length),
"chunk_seconds": float(mrt.config.chunk_length),
"crossfade_seconds": float(mrt.config.crossfade_length),
"selected_step": os.getenv("MRT_CKPT_STEP"),
"repo": os.getenv("MRT_CKPT_REPO"),
}
@app.post("/model/swap")
def model_swap(step: int = Form(...)):
# stop any active jam if you want to be strict (not shown)
os.environ["MRT_CKPT_STEP"] = str(step)
global _MRT
with _MRT_LOCK:
_MRT = None # force re-create on next get_mrt()
# optionally pre-warm here by calling get_mrt()
return {"reloaded": True, "step": step}
@app.post("/model/assets/load")
def model_assets_load(repo_id: str = Form(None)):
global _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
ok, msg = asset_manager.load_finetune_assets_from_hf(repo_id, get_mrt())
# Sync globals after loading
_MEAN_EMBED = asset_manager.mean_embed
_CENTROIDS = asset_manager.centroids
_ASSETS_REPO_ID = asset_manager.assets_repo_id
return {"ok": ok, "message": msg, "repo_id": _ASSETS_REPO_ID,
"mean": _MEAN_EMBED is not None,
"centroids": None if _CENTROIDS is None else int(_CENTROIDS.shape[0])}
@app.get("/model/assets/status")
def model_assets_status():
d = None
try:
d = int(get_mrt().style_model.config.embedding_dim)
except Exception:
pass
return {
"repo_id": _ASSETS_REPO_ID,
"mean_loaded": _MEAN_EMBED is not None,
"centroids_loaded": False if _CENTROIDS is None else True,
"centroid_count": None if _CENTROIDS is None else int(_CENTROIDS.shape[0]),
"embedding_dim": d,
}
@app.get("/model/config")
def model_config():
"""
Lightweight config snapshot:
- never calls get_mrt() (no model build / no downloads)
- never calls snapshot_download()
- reports whether a model instance is currently loaded in memory
- best-effort local checkpoint presence (no network)
"""
# Read-only snapshot of in-memory model presence
with _MRT_LOCK:
loaded = (_MRT is not None)
size = os.getenv("MRT_SIZE", "large")
repo = os.getenv("MRT_CKPT_REPO")
rev = os.getenv("MRT_CKPT_REV", "main")
step = os.getenv("MRT_CKPT_STEP")
assets = os.getenv("MRT_ASSETS_REPO")
# Use CheckpointManager for local cache probe (no network)
local_ckpt = None
if step:
try:
from pathlib import Path
import re
step_escaped = re.escape(str(step))
candidates = []
for root in ("/home/appuser/.cache/mrt_ckpt/extracted",
"/home/appuser/.cache/mrt_ckpt/repo"):
p = Path(root)
if not p.exists():
continue
# Look for exact "checkpoint_<step>" directories anywhere under these roots
for d in p.rglob(f"checkpoint_{step}"):
if d.is_dir():
candidates.append(str(d))
local_ckpt = candidates[0] if candidates else None
except Exception:
local_ckpt = None
return {
"size": size,
"repo": repo,
"revision": rev,
"selected_step": step,
"assets_repo": assets,
# in-memory + local cache hints (no network, no model build)
"loaded": loaded,
"active_jam": _any_jam_running(),
"local_checkpoint_dir": local_ckpt, # None if not found locally
# steering assets currently resident in memory
"mean_loaded": (_MEAN_EMBED is not None),
"centroids_loaded": (_CENTROIDS is not None),
"centroid_count": (None if _CENTROIDS is None else int(_CENTROIDS.shape[0])),
"warmup_done": bool(_WARMED),
}
@app.get("/model/checkpoints")
def model_checkpoints(repo_id: str, revision: str = "main"):
steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
@app.post("/model/select")
def model_select(req: ModelSelect):
"""
Swap model/checkpoint/assets. If req.prewarm is True, run the full bar-aligned warmup
(_mrt_warmup) synchronously so we only report warmed once the new model is actually ready.
"""
global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID, _WARMED
# 1) Validate the request (no side-effects)
success, validation_result = model_selector.validate_selection(req)
if not success:
if "error" in validation_result:
raise HTTPException(status_code=400, detail=validation_result["error"])
return {"ok": False, **validation_result}
# Augment response surface
validation_result["active_jam"] = _any_jam_running()
# Dry-run path
if req.dry_run:
return {"ok": True, "dry_run": True, **validation_result}
if isinstance(req.step, str) and req.step.lower() == "none": # user asked for stock base
# Clear any resident finetune assets so /model/config reflects "no assets"
asset_manager.mean_embed = None
asset_manager.centroids = None
asset_manager.assets_repo_id = None
# keep module-level mirrors in sync
_MEAN_EMBED = None
_CENTROIDS = None
_ASSETS_REPO_ID = None
# 2) Handle jam policy
if _any_jam_running():
if req.stop_active:
_stop_all_jams()
else:
raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
# 3) Compute environment changes (no mutation yet)
env_changes = model_selector.prepare_env_changes(req, validation_result)
# Keep current env for rollback
old_env = {
"MRT_SIZE": os.getenv("MRT_SIZE"),
"MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
"MRT_CKPT_REV": os.getenv("MRT_CKPT_REV"),
"MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
"MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
}
try:
# 4) Apply env atomically
for key, value in env_changes.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = str(value)
# 5) Force rebuild of the model and reset warmup state
with _MRT_LOCK:
_MRT = None
with _WARMUP_LOCK:
_WARMED = False # ← critical: don't leak previous model's warmed state
# 6) Load finetune assets if requested (mean/centroids)
if req.sync_assets and validation_result.get("assets_repo"):
ok, msg = asset_manager.load_finetune_assets_from_hf(
validation_result["assets_repo"],
None # don't implicitly instantiate model here; we'll do it below
)
if ok:
_MEAN_EMBED = asset_manager.mean_embed
_CENTROIDS = asset_manager.centroids
_ASSETS_REPO_ID = asset_manager.assets_repo_id
else:
logging.warning("Asset sync skipped/failed: %s", msg)
# 7) Prewarm behavior:
# - If prewarm=True, run the *real* bar-aligned warmup synchronously.
# - This will instantiate the new MRT and set _WARMED=True on success.
if req.prewarm:
_mrt_warmup() # builds MRT internally via get_mrt(), runs generate_chunk, sets _WARMED
# Optional: if you want to always ensure MRT exists (even without prewarm), uncomment:
# else:
# _ = get_mrt()
return {
"ok": True,
**validation_result,
"warmup_done": bool(_WARMED),
}
except Exception as e:
# 8) Roll back env on failure
for k, v in old_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
# Also reset model pointer & warmed flag to a safe state
with _MRT_LOCK:
_MRT = None
with _WARMUP_LOCK:
_WARMED = False
logging.exception("Model select failed: %s", e)
raise HTTPException(status_code=500, detail=f"Model select failed: {e}")
# ----------------------------
# one-shot generation
# ----------------------------
@app.post("/generate")
def generate(
loop_audio: UploadFile = File(...),
bpm: float = Form(...),
bars: int = Form(8),
beats_per_bar: int = Form(4),
styles: str = Form("acid house"),
style_weights: str = Form(""),
loop_weight: float = Form(1.0),
loudness_mode: str = Form("auto"),
loudness_headroom_db: float = Form(1.0),
guidance_weight: float = Form(5.0),
temperature: float = Form(1.1),
topk: int = Form(40),
target_sample_rate: int | None = Form(None),
intro_bars_to_drop: int = Form(0), # <— NEW
):
# Read file
data = loop_audio.file.read()
if not data:
return {"error": "Empty file"}
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(data)
tmp_path = tmp.name
# Parse styles + weights
extra_styles = [s for s in (styles.split(",") if styles else []) if s.strip()]
weights = [float(x) for x in style_weights.split(",")] if style_weights else None
mrt = get_mrt() # warm once, in this worker thread
# Temporarily override MRT inference knobs for this request
with mrt_overrides(mrt,
guidance_weight=guidance_weight,
temperature=temperature,
topk=topk):
wav, loud_stats = generate_loop_continuation_with_mrt(
mrt,
input_wav_path=tmp_path,
bpm=bpm,
extra_styles=extra_styles,
style_weights=weights,
bars=bars,
beats_per_bar=beats_per_bar,
loop_weight=loop_weight,
loudness_mode=loudness_mode,
loudness_headroom_db=loudness_headroom_db,
intro_bars_to_drop=intro_bars_to_drop, # <— pass through
)
# 1) Figure out the desired SR
inp_info = sf.info(tmp_path)
input_sr = int(inp_info.samplerate)
target_sr = int(target_sample_rate or input_sr)
# 2) Convert to target SR + snap to exact bars
cur_sr = int(mrt.sample_rate)
x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
expected_secs = float(bars) * seconds_per_bar
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
# 3) Encode WAV once (no extra write)
audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
loop_duration_seconds = total_samples / float(target_sr)
# 4) Metadata
metadata = {
"bpm": int(round(bpm)),
"bars": int(bars),
"beats_per_bar": int(beats_per_bar),
"styles": extra_styles,
"style_weights": weights,
"loop_weight": loop_weight,
"loudness": loud_stats,
"sample_rate": int(target_sr),
"channels": int(channels),
"crossfade_seconds": mrt.config.crossfade_length,
"total_samples": int(total_samples),
"seconds_per_bar": seconds_per_bar,
"loop_duration_seconds": loop_duration_seconds,
"guidance_weight": guidance_weight,
"temperature": temperature,
"topk": topk,
}
return {"audio_base64": audio_b64, "metadata": metadata}
# new endpoint to return a bar-aligned chunk without the need for combined audio
@app.post("/generate_style")
def generate_style(
bpm: float = Form(...),
bars: int = Form(8),
beats_per_bar: int = Form(4),
styles: str = Form("warmup"),
style_weights: str = Form(""),
guidance_weight: float = Form(1.1),
temperature: float = Form(1.1),
topk: int = Form(40),
target_sample_rate: int | None = Form(None),
intro_bars_to_drop: int = Form(0),
):
"""
Style-only, bar-aligned generation (no input audio).
Seeds with 10s of silent context; outputs exactly `bars` at the requested BPM.
"""
mrt = get_mrt()
# Override sampling knobs just for this request
with mrt_overrides(mrt,
guidance_weight=guidance_weight,
temperature=temperature,
topk=topk):
wav, _ = generate_style_only_with_mrt(
mrt,
bpm=bpm,
bars=bars,
beats_per_bar=beats_per_bar,
styles=styles,
style_weights=style_weights,
intro_bars_to_drop=intro_bars_to_drop,
)
# Determine target SR (defaults to model SR = 48k)
cur_sr = int(mrt.sample_rate)
target_sr = int(target_sample_rate or cur_sr)
x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
expected_secs = float(bars) * seconds_per_bar
# Snap exactly to musical length at the requested sample rate
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
metadata = {
"bpm": int(round(bpm)),
"bars": int(bars),
"beats_per_bar": int(beats_per_bar),
"styles": [s.strip() for s in (styles.split(",") if styles else []) if s.strip()],
"style_weights": [float(y) for y in style_weights.split(",")] if style_weights else None,
"sample_rate": int(target_sr),
"channels": int(channels),
"crossfade_seconds": mrt.config.crossfade_length,
"seconds_per_bar": seconds_per_bar,
"loop_duration_seconds": total_samples / float(target_sr),
"guidance_weight": guidance_weight,
"temperature": temperature,
"topk": topk,
}
return {"audio_base64": audio_b64, "metadata": metadata}
# ----------------------------
# the 'keep jamming' button
# ----------------------------
@app.post("/jam/start")
def jam_start(
loop_audio: UploadFile = File(...),
bpm: float = Form(...),
bars_per_chunk: int = Form(4),
beats_per_bar: int = Form(4),
styles: str = Form(""),
style_weights: str = Form(""),
loop_weight: float = Form(1.0),
# NEW steering params:
mean: float = Form(0.0),
centroid_weights: str = Form(""),
loudness_mode: str = Form("auto"),
loudness_headroom_db: float = Form(1.0),
guidance_weight: float = Form(1.1),
temperature: float = Form(1.1),
topk: int = Form(40),
target_sample_rate: int | None = Form(None),
):
asset_manager.ensure_assets_loaded(get_mrt())
# enforce single active jam per GPU
with jam_lock:
for sid, w in list(jam_registry.items()):
if w.is_alive():
raise HTTPException(status_code=429, detail="A jam is already running. Try again later.")
# read input + prep context/style (reuse your existing code)
data = loop_audio.file.read()
if not data: raise HTTPException(status_code=400, detail="Empty file")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(data); tmp_path = tmp.name
mrt = get_mrt()
loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
# build tail context + style vec (tail-biased)
codec_fps = float(mrt.codec.frame_rate)
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
# Parse client style fields (preserves your semantics)
text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
try:
tw = [float(x) for x in style_weights.split(",")] if style_weights else []
except ValueError:
tw = []
try:
cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
except ValueError:
cw = []
# Compute loop-tail embed once (same as before)
loop_tail_embed = mrt.embed_style(loop_tail)
# Build final style vector:
# - identical to your previous mix when mean==0 and cw is empty
# - otherwise includes mean and centroid components (weights auto-normalized)
style_vec = build_style_vector(
mrt,
text_styles=text_list,
text_weights=tw,
loop_embed=loop_tail_embed,
loop_weight=float(loop_weight),
mean_weight=float(mean),
centroid_weights=cw,
).astype(np.float32, copy=False)
# target SR (default input SR)
inp_info = sf.info(tmp_path)
input_sr = int(inp_info.samplerate)
target_sr = int(target_sample_rate or input_sr)
params = JamParams(
bpm=bpm,
beats_per_bar=beats_per_bar,
bars_per_chunk=bars_per_chunk,
target_sr=target_sr,
loudness_mode=loudness_mode,
headroom_db=loudness_headroom_db,
style_vec=style_vec,
ref_loop=loop_tail, # For loudness matching
combined_loop=loop, # NEW: Full loop for context setup
guidance_weight=guidance_weight,
temperature=temperature,
topk=topk
)
worker = JamWorker(mrt, params)
sid = str(uuid.uuid4())
with jam_lock:
jam_registry[sid] = worker
worker.start()
return {"session_id": sid}
@app.get("/jam/next")
def jam_next(session_id: str):
"""
Get the next sequential chunk in the jam session.
This ensures chunks are delivered in order without gaps.
"""
with jam_lock:
worker = jam_registry.get(session_id)
if worker is None or not worker.is_alive():
raise HTTPException(status_code=404, detail="Session not found")
# Get the next sequential chunk (this blocks until ready)
chunk = worker.get_next_chunk()
if chunk is None:
raise HTTPException(status_code=408, detail="Chunk not ready within timeout")
return {
"chunk": {
"index": chunk.index,
"audio_base64": chunk.audio_base64,
"metadata": chunk.metadata
}
}
@app.post("/jam/consume")
def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
"""
Mark a chunk as consumed by the frontend.
This helps the worker manage its buffer and generation flow.
"""
with jam_lock:
worker = jam_registry.get(session_id)
if worker is None or not worker.is_alive():
raise HTTPException(status_code=404, detail="Session not found")
worker.mark_chunk_consumed(chunk_index)
return {"consumed": chunk_index}
@app.post("/jam/stop")
def jam_stop(session_id: str = Body(..., embed=True)):
with jam_lock:
worker = jam_registry.get(session_id)
if worker is None:
raise HTTPException(status_code=404, detail="Session not found")
worker.stop()
worker.join(timeout=5.0)
if worker.is_alive():
# It’s daemon=True, so it won’t block process exit, but report it
print(f"⚠️ JamWorker {session_id} did not stop within timeout")
with jam_lock:
jam_registry.pop(session_id, None)
return {"stopped": True}
@app.post("/jam/update")
def jam_update(
session_id: str = Form(...),
# knobs
guidance_weight: Optional[float] = Form(None),
temperature: Optional[float] = Form(None),
topk: Optional[int] = Form(None),
# styles
styles: str = Form(""),
style_weights: str = Form(""),
loop_weight: Optional[float] = Form(None),
use_current_mix_as_style: bool = Form(False),
# NEW steering
mean: Optional[float] = Form(None),
centroid_weights: str = Form(""),
):
asset_manager.ensure_assets_loaded(get_mrt())
with jam_lock:
worker = jam_registry.get(session_id)
if worker is None or not worker.is_alive():
raise HTTPException(status_code=404, detail="Session not found")
# 1) fast knob updates
if any(v is not None for v in (guidance_weight, temperature, topk)):
worker.update_knobs(
guidance_weight=guidance_weight,
temperature=temperature,
topk=topk
)
# 2) rebuild style only if asked
wants_style_update = (
use_current_mix_as_style
or (styles.strip() != "")
or (mean is not None)
or (centroid_weights.strip() != "")
)
if not wants_style_update:
return {"ok": True}
# --- parse inputs (robust) ---
text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
try:
tw = [float(x) for x in style_weights.split(",")] if style_weights else []
except ValueError:
tw = []
try:
cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
except ValueError:
cw = []
# Clamp centroid weights to available centroids (if loaded)
max_c = 0 if _CENTROIDS is None else int(_CENTROIDS.shape[0])
if max_c and len(cw) > max_c:
cw = cw[:max_c]
# Snapshot minimal state under lock
with worker._lock:
combined_loop = worker.params.combined_loop if use_current_mix_as_style else None
lw = None
if use_current_mix_as_style:
lw = 1.0 if (loop_weight is None) else float(loop_weight)
mrt = worker.mrt
# Heavy work OUTSIDE the lock
loop_embed = None
if combined_loop is not None:
loop_embed = mrt.embed_style(combined_loop)
style_vec = build_style_vector(
mrt,
text_styles=text_list,
text_weights=tw,
loop_embed=loop_embed, # None => ignored by builder
loop_weight=lw, # None => ignored by builder
mean_weight=(None if mean is None else float(mean)),
centroid_weights=cw, # [] => ignored by builder
).astype(np.float32, copy=False)
# Swap atomically
with worker._lock:
worker.params.style_vec = style_vec
return {"ok": True}
@app.post("/jam/reseed")
def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
with jam_lock:
worker = jam_registry.get(session_id)
if worker is None or not worker.is_alive():
raise HTTPException(status_code=404, detail="Session not found")
# Option 1: use uploaded new “combined” bounce from the app
if loop_audio is not None:
data = loop_audio.file.read()
if not data:
raise HTTPException(status_code=400, detail="Empty file")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(data); path = tmp.name
wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
else:
# Option 2: reseed from what we’ve been streaming (the model side)
# (Usually better to reseed from the Swift-side “combined” mix you trust.)
s = getattr(worker, "_stream", None)
if s is None or s.shape[0] == 0:
raise HTTPException(status_code=400, detail="No internal stream to reseed from")
wav = au.Waveform(s.astype(np.float32, copy=False), int(worker.mrt.sample_rate)).as_stereo()
worker.reseed_from_waveform(wav)
return {"ok": True}
@app.post("/jam/reseed_splice")
def jam_reseed_splice(
session_id: str = Form(...),
anchor_bars: float = Form(2.0), # how much of the original to re-inject
combined_audio: UploadFile = File(None), # preferred: Swift supplies the current combined mix
):
worker = jam_registry.get(session_id)
if worker is None or not worker.is_alive():
raise HTTPException(status_code=404, detail="Session not found")
# Build a waveform to reseed from
wav = None
if combined_audio is not None:
data = combined_audio.file.read()
if not data:
raise HTTPException(status_code=400, detail="Empty combined_audio")
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(data)
path = tmp.name
wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
else:
# Fallback: reseed from the model’s internal stream (less ideal than the Swift-side bounce)
s = getattr(worker, "_stream", None)
if s is None or s.shape[0] == 0:
raise HTTPException(status_code=400, detail="No audio available to reseed from")
wav = au.Waveform(s.astype(np.float32, copy=False), int(worker.mrt.sample_rate)).as_stereo()
# Perform the splice reseed
worker.reseed_splice(wav, anchor_bars=float(anchor_bars))
return {"ok": True, "anchor_bars": float(anchor_bars)}
@app.get("/jam/status")
def jam_status(session_id: str):
with jam_lock:
worker = jam_registry.get(session_id)
if worker is None:
raise HTTPException(status_code=404, detail="Session not found")
running = worker.is_alive()
# Snapshot safely
with worker._lock:
last_generated = int(worker.idx)
last_delivered = int(worker._last_delivered_index)
queued = len(worker.outbox)
buffer_ahead = last_generated - last_delivered
p = worker.params
spb = p.beats_per_bar * (60.0 / p.bpm)
chunk_secs = p.bars_per_chunk * spb
return {
"running": running,
"last_generated_index": last_generated, # Last chunk that finished generating
"last_delivered_index": last_delivered, # Last chunk sent to frontend
"buffer_ahead": buffer_ahead, # How many chunks ahead we are
"queued_chunks": queued, # Total chunks in outbox
"bpm": p.bpm,
"beats_per_bar": p.beats_per_bar,
"bars_per_chunk": p.bars_per_chunk,
"seconds_per_bar": spb,
"chunk_duration_seconds": chunk_secs,
"target_sample_rate": p.target_sr,
"last_chunk_started_at": worker.last_chunk_started_at,
"last_chunk_completed_at": worker.last_chunk_completed_at,
}
@app.get("/health")
def health():
# 1) Template mode → not ready (encourage duplication on GPU)
if SPACE_MODE != "serve":
return JSONResponse(
status_code=503,
content={
"ok": False,
"status": "template_mode",
"message": "This Space is a GPU template. Duplicate it and select an L40s/A100-class runtime to use the API.",
"mode": SPACE_MODE,
},
)
# 2) Runtime hardware probe
probe = _gpu_probe()
if not probe["ok"] or not probe["has_gpu"] or probe.get("backend") != "gpu":
return JSONResponse(
status_code=503,
content={
"ok": False,
"status": "gpu_unavailable",
"message": "GPU is not visible to JAX. Select a GPU runtime (e.g., L40s) to serve.",
"probe": probe,
"mode": SPACE_MODE,
},
)
# 3) Ready; include operational hints
warmed = bool(_WARMED)
with jam_lock:
active_jams = sum(1 for w in jam_registry.values() if w.is_alive())
return {
"ok": True,
"status": "ready" if warmed else "initializing",
"mode": SPACE_MODE,
"warmed": warmed,
"active_jams": active_jams,
"probe": probe,
}
@app.middleware("http")
async def log_requests(request: Request, call_next):
rid = request.headers.get("X-Request-ID", "-")
print(f"📥 {request.method} {request.url.path}?{request.url.query} [rid={rid}]")
try:
response = await call_next(request)
except Exception as e:
print(f"💥 exception for {request.url.path} [rid={rid}]: {e}")
raise
print(f"📤 {response.status_code} {request.url.path} [rid={rid}]")
return response
# ----------------------------
# websockets route
# ----------------------------
@app.websocket("/ws/jam")
async def ws_jam(websocket: WebSocket):
await websocket.accept()
sid = None
worker = None
binary_audio = False
mode = "rt" # or "bar"
# NEW: capture ws in closure
async def send_json(obj):
return await send_json_safe(websocket, obj)
try:
while True:
raw = await websocket.receive_text()
msg = json.loads(raw)
mtype = msg.get("type")
# --- START ---
if mtype == "start":
binary_audio = bool(msg.get("binary_audio", False))
mode = msg.get("mode", "rt")
params = msg.get("params", {}) or {}
sid = msg.get("session_id")
# attach or create
if sid:
with jam_lock:
worker = jam_registry.get(sid)
if worker is None or not worker.is_alive():
await send_json({"type":"error","error":"Session not found"})
continue
else:
# optionally accept base64 loop and start a new worker (bar-mode)
if mode == "bar":
loop_b64 = msg.get("loop_audio_b64")
if not loop_b64:
await send_json({"type":"error","error":"loop_audio_b64 required for mode=bar when no session_id"})
continue
loop_bytes = base64.b64decode(loop_b64)
# mimic /jam/start
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(loop_bytes); tmp_path = tmp.name
# build JamParams similar to /jam/start
mrt = get_mrt()
model_sr = int(mrt.sample_rate) # typically 48000
# Defaults for WS: raw loudness @ model SR, unless overridden by client:
target_sr = int(params.get("target_sr", model_sr))
loudness_mode = params.get("loudness_mode", "none")
headroom_db = float(params.get("headroom_db", 1.0))
loop = au.Waveform.from_file(tmp_path).resample(mrt.sample_rate).as_stereo()
codec_fps = float(mrt.codec.frame_rate)
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
bpm = float(params.get("bpm", 120.0))
bpb = int(params.get("beats_per_bar", 4))
loop_tail = take_bar_aligned_tail(loop, bpm, bpb, ctx_seconds)
# style vector (loop + extra styles)
embeds, weights = [mrt.embed_style(loop_tail)], [float(params.get("loop_weight", 1.0))]
extra = [s for s in (params.get("styles","").split(",")) if s.strip()]
sw = [float(x) for x in params.get("style_weights","").split(",") if x.strip()]
for i, s in enumerate(extra):
embeds.append(mrt.embed_style(s.strip()))
weights.append(sw[i] if i < len(sw) else 1.0)
wsum = sum(weights) or 1.0
weights = [w/wsum for w in weights]
style_vec = np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
# target SR fallback: input SR
inp_info = sf.info(tmp_path)
target_sr = int(params.get("target_sr", int(inp_info.samplerate)))
# Build JamParams for WS bar-mode
jp = JamParams(
bpm=bpm, beats_per_bar=bpb, bars_per_chunk=int(params.get("bars_per_chunk", 8)),
target_sr=target_sr,
loudness_mode=loudness_mode, headroom_db=headroom_db,
style_vec=style_vec,
ref_loop=None if loudness_mode == "none" else loop_tail, # disable match by default
combined_loop=loop,
guidance_weight=float(params.get("guidance_weight", 1.1)),
temperature=float(params.get("temperature", 1.1)),
topk=int(params.get("topk", 40)),
)
worker = JamWorker(get_mrt(), jp)
sid = str(uuid.uuid4())
with jam_lock:
# single active jam per GPU, mirroring /jam/start
for _sid, w in list(jam_registry.items()):
if w.is_alive():
await send_json({"type":"error","error":"A jam is already running"})
worker = None; sid = None
break
if worker is not None:
jam_registry[sid] = worker
worker.start()
else:
# mode == "rt" (Colab-style, no loop context)
mrt = get_mrt()
state = mrt.init_state()
# Build silent context (10s) tokens
codec_fps = float(mrt.codec.frame_rate)
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
sr = int(mrt.sample_rate)
samples = int(max(1, round(ctx_seconds * sr)))
silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
state.context_tokens = tokens
# Parse params (including steering)
asset_manager.ensure_assets_loaded(get_mrt())
styles_str = params.get("styles", "warmup") or ""
style_weights_str = params.get("style_weights", "") or ""
mean_w = float(params.get("mean", 0.0) or 0.0)
cw_str = str(params.get("centroid_weights", "") or "")
text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
try:
text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
except ValueError:
text_w = []
try:
cw = [float(x) for x in cw_str.split(",") if x.strip() != ""]
except ValueError:
cw = []
# Clamp centroid weights to available centroids
if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
cw = cw[: int(_CENTROIDS.shape[0])]
# Build initial style vector (no loop_embed in rt mode)
style_vec = build_style_vector(
mrt,
text_styles=text_list,
text_weights=text_w,
loop_embed=None,
loop_weight=None,
mean_weight=mean_w,
centroid_weights=cw,
)
# Stash rt session fields
websocket._mrt = mrt
websocket._state = state
websocket._style_cur = style_vec
websocket._style_tgt = style_vec
websocket._style_ramp_s = float(params.get("style_ramp_seconds", 0.0))
websocket._rt_mean = mean_w
websocket._rt_centroid_weights = cw
websocket._rt_running = True
websocket._rt_sr = sr
websocket._rt_topk = int(params.get("topk", 40))
websocket._rt_temp = float(params.get("temperature", 1.1))
websocket._rt_guid = float(params.get("guidance_weight", 1.1))
websocket._pace = params.get("pace", "asap") # "realtime" | "asap"
# (Optional) report whether steering assets were loaded
assets_ok = (_MEAN_EMBED is not None) or (_CENTROIDS is not None)
await send_json({"type": "started", "mode": "rt", "steering_assets": "loaded" if assets_ok else "none"})
# kick off the ~2s streaming loop
async def _rt_loop():
try:
mrt = websocket._mrt
chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
target_next = time.perf_counter()
while websocket._rt_running:
mrt.guidance_weight = websocket._rt_guid
mrt.temperature = websocket._rt_temp
mrt.topk = websocket._rt_topk
# ramp style
ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
if ramp <= 0.0:
websocket._style_cur = websocket._style_tgt
else:
step = min(1.0, chunk_secs / ramp)
websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)
wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style_cur)
websocket._state = new_state
x = wav.samples.astype(np.float32, copy=False)
buf = io.BytesIO()
sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
ok = True
if binary_audio:
try:
await websocket.send_bytes(buf.getvalue())
ok = await send_json({"type": "chunk_meta", "metadata": {"sample_rate": mrt.sample_rate}})
except Exception:
ok = False
else:
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
ok = await send_json({"type": "chunk", "audio_base64": b64,
"metadata": {"sample_rate": mrt.sample_rate}})
if not ok:
break
if getattr(websocket, "_pace", "asap") == "realtime":
t1 = time.perf_counter()
target_next += chunk_secs
sleep_s = max(0.0, target_next - t1 - 0.02)
if sleep_s > 0:
await asyncio.sleep(sleep_s)
except asyncio.CancelledError:
pass
except Exception:
pass
websocket._rt_task = asyncio.create_task(_rt_loop())
continue # skip the “bar-mode started” message below
await send_json({"type":"started","session_id": sid, "mode": mode})
# if we’re in bar-mode, begin pushing chunks as they arrive
if mode == "bar" and worker is not None:
async def _pump():
while True:
if not worker.is_alive():
break
chunk = worker.get_next_chunk(timeout=60.0)
if chunk is None:
continue
if binary_audio:
await websocket.send_bytes(base64.b64decode(chunk.audio_base64))
await send_json({"type":"chunk_meta","index":chunk.index,"metadata":chunk.metadata})
else:
await send_json({"type":"chunk","index":chunk.index,
"audio_base64":chunk.audio_base64,"metadata":chunk.metadata})
asyncio.create_task(_pump())
# --- UPDATES (bar or rt) ---
elif mtype == "update":
if mode == "bar":
if not sid:
await send_json({"type":"error","error":"No session_id yet"}); return
# fan values straight into your existing HTTP handler:
res = jam_update(
session_id=sid,
guidance_weight=msg.get("guidance_weight"),
temperature=msg.get("temperature"),
topk=msg.get("topk"),
styles=msg.get("styles",""),
style_weights=msg.get("style_weights",""),
loop_weight=msg.get("loop_weight"),
use_current_mix_as_style=bool(msg.get("use_current_mix_as_style", False)),
)
await send_json({"type":"status", **res}) # {"ok": True}
else:
# rt-mode: there’s no JamWorker; update the local knobs/state
websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
# NEW steering fields
if "mean" in msg and msg["mean"] is not None:
try: websocket._rt_mean = float(msg["mean"])
except: websocket._rt_mean = 0.0
if "centroid_weights" in msg:
cw = [w.strip() for w in str(msg["centroid_weights"]).split(",") if w.strip() != ""]
try:
websocket._rt_centroid_weights = [float(x) for x in cw]
except:
websocket._rt_centroid_weights = []
# styles / text weights (optional, comma-separated)
styles_str = msg.get("styles", None)
style_weights_str = msg.get("style_weights", "")
text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
asset_manager.ensure_assets_loaded(get_mrt())
websocket._style_tgt = build_style_vector(
websocket._mrt,
text_styles=text_list,
text_weights=text_w,
loop_embed=None,
loop_weight=None,
mean_weight=float(websocket._rt_mean),
centroid_weights=websocket._rt_centroid_weights,
)
# optionally allow live changes to ramp:
if "style_ramp_seconds" in msg:
try: websocket._style_ramp_s = float(msg["style_ramp_seconds"])
except: pass
await send_json({"type":"status","updated":"rt-knobs+style"})
elif mtype == "consume" and mode == "bar":
with jam_lock:
worker = jam_registry.get(msg.get("session_id"))
if worker is not None:
worker.mark_chunk_consumed(int(msg.get("chunk_index", -1)))
elif mtype == "reseed" and mode == "bar":
with jam_lock:
worker = jam_registry.get(msg.get("session_id"))
if worker is None or not worker.is_alive():
await send_json({"type":"error","error":"Session not found"}); continue
loop_b64 = msg.get("loop_audio_b64")
if not loop_b64:
await send_json({"type":"error","error":"loop_audio_b64 required"}); continue
loop_bytes = base64.b64decode(loop_b64)
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(loop_bytes); path = tmp.name
wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
worker.reseed_from_waveform(wav)
await send_json({"type":"status","reseeded":True})
elif mtype == "reseed_splice" and mode == "bar":
with jam_lock:
worker = jam_registry.get(msg.get("session_id"))
if worker is None or not worker.is_alive():
await send_json({"type":"error","error":"Session not found"}); continue
anchor = float(msg.get("anchor_bars", 2.0))
b64 = msg.get("combined_audio_b64")
if b64:
data = base64.b64decode(b64)
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(data); path = tmp.name
wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
worker.reseed_splice(wav, anchor_bars=anchor)
else:
# fallback: model-side stream splice
worker.reseed_splice(worker.params.combined_loop, anchor_bars=anchor)
await send_json({"type":"status","splice":anchor})
elif mtype == "stop":
if mode == "rt":
websocket._rt_running = False
task = getattr(websocket, "_rt_task", None)
if task is not None:
task.cancel()
try: await task
except asyncio.CancelledError: pass
await send_json({"type":"stopped"})
break # <- add this if you want to end the socket after stop
elif mtype == "ping":
await send_json({"type":"pong"})
else:
await send_json({"type":"error","error":f"Unknown type {mtype}"})
except WebSocketDisconnect:
# best-effort cleanup for bar-mode sessions started within this socket (optional)
pass
except Exception as e:
try:
await send_json({"type":"error","error":str(e)})
except Exception:
pass
finally:
try:
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
except Exception:
pass
@app.get("/ping")
def ping():
return {"ok": True}
@app.get("/", response_class=Response)
def read_root():
"""Root endpoint that explains what this API does"""
try:
html_file = Path(__file__).parent / "documentation.html"
html_content = html_file.read_text(encoding='utf-8')
except FileNotFoundError:
# Fallback if file is missing
html_content = """
<!DOCTYPE html>
<html><body>
<h1>MagentaRT Research API</h1>
<p>Documentation file not found. Please check documentation.html</p>
</body></html>
"""
return Response(content=html_content, media_type="text/html")
@app.get("/lil_demo_540p.mp4")
def demo_video():
return FileResponse(Path(__file__).parent / "lil_demo_540p.mp4", media_type="video/mp4")
@app.get("/tester", response_class=HTMLResponse)
def tester():
html_path = Path(__file__).parent / "magentaRT_rt_tester.html"
return HTMLResponse(
html_path.read_text(encoding="utf-8"),
headers={"Cache-Control": "no-store"} # avoid sticky caches while iterating
) |