Commit
·
a8de318
1
Parent(s):
80346b6
attempting to add centroids/mean
Browse files
app.py
CHANGED
@@ -32,7 +32,7 @@ except Exception:
|
|
32 |
|
33 |
from magenta_rt import system, audio as au
|
34 |
import numpy as np
|
35 |
-
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
|
36 |
import tempfile, io, base64, math, threading
|
37 |
from fastapi.middleware.cors import CORSMiddleware
|
38 |
from contextlib import contextmanager
|
@@ -70,6 +70,78 @@ import re, tarfile
|
|
70 |
from pathlib import Path
|
71 |
from huggingface_hub import snapshot_download
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def _resolve_checkpoint_dir() -> str | None:
|
74 |
repo_id = os.getenv("MRT_CKPT_REPO")
|
75 |
if not repo_id:
|
@@ -540,6 +612,9 @@ def generate_loop_continuation_with_mrt(
|
|
540 |
|
541 |
return out, loud_stats
|
542 |
|
|
|
|
|
|
|
543 |
def generate_style_only_with_mrt(
|
544 |
mrt,
|
545 |
bpm: float,
|
@@ -610,6 +685,92 @@ def generate_style_only_with_mrt(
|
|
610 |
|
611 |
return out, None # loudness stats not applicable (no reference)
|
612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
613 |
|
614 |
|
615 |
|
@@ -669,7 +830,6 @@ def _mrt_warmup():
|
|
669 |
beats_per_bar = 4
|
670 |
|
671 |
# --- build a silent, stereo context of ctx_seconds ---
|
672 |
-
import numpy as np, soundfile as sf
|
673 |
samples = int(max(1, round(ctx_seconds * sr)))
|
674 |
silent = np.zeros((samples, 2), dtype=np.float32)
|
675 |
|
@@ -745,6 +905,28 @@ def model_swap(step: int = Form(...)):
|
|
745 |
# optionally pre-warm here by calling get_mrt()
|
746 |
return {"reloaded": True, "step": step}
|
747 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
748 |
@app.post("/generate")
|
749 |
def generate(
|
750 |
loop_audio: UploadFile = File(...),
|
@@ -911,6 +1093,11 @@ def jam_start(
|
|
911 |
styles: str = Form(""),
|
912 |
style_weights: str = Form(""),
|
913 |
loop_weight: float = Form(1.0),
|
|
|
|
|
|
|
|
|
|
|
914 |
loudness_mode: str = Form("auto"),
|
915 |
loudness_headroom_db: float = Form(1.0),
|
916 |
guidance_weight: float = Form(1.1),
|
@@ -918,6 +1105,8 @@ def jam_start(
|
|
918 |
topk: int = Form(40),
|
919 |
target_sample_rate: int | None = Form(None),
|
920 |
):
|
|
|
|
|
921 |
# enforce single active jam per GPU
|
922 |
with jam_lock:
|
923 |
for sid, w in list(jam_registry.items()):
|
@@ -938,16 +1127,32 @@ def jam_start(
|
|
938 |
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
939 |
loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
|
940 |
|
941 |
-
#
|
942 |
-
|
943 |
-
|
944 |
-
|
945 |
-
|
946 |
-
|
947 |
-
|
948 |
-
|
949 |
-
|
950 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
951 |
|
952 |
# target SR (default input SR)
|
953 |
inp_info = sf.info(tmp_path)
|
@@ -1036,27 +1241,33 @@ def jam_stop(session_id: str = Body(..., embed=True)):
|
|
1036 |
jam_registry.pop(session_id, None)
|
1037 |
return {"stopped": True}
|
1038 |
|
1039 |
-
@app.post("/jam/update")
|
1040 |
def jam_update(
|
1041 |
session_id: str = Form(...),
|
1042 |
|
1043 |
-
# knobs
|
1044 |
guidance_weight: Optional[float] = Form(None),
|
1045 |
temperature: Optional[float] = Form(None),
|
1046 |
topk: Optional[int] = Form(None),
|
1047 |
|
1048 |
-
# styles
|
1049 |
styles: str = Form(""),
|
1050 |
style_weights: str = Form(""),
|
1051 |
-
loop_weight: Optional[float] = Form(None),
|
1052 |
use_current_mix_as_style: bool = Form(False),
|
|
|
|
|
|
|
|
|
1053 |
):
|
|
|
|
|
1054 |
with jam_lock:
|
1055 |
worker = jam_registry.get(session_id)
|
1056 |
if worker is None or not worker.is_alive():
|
1057 |
raise HTTPException(status_code=404, detail="Session not found")
|
1058 |
|
1059 |
-
#
|
1060 |
if any(v is not None for v in (guidance_weight, temperature, topk)):
|
1061 |
worker.update_knobs(
|
1062 |
guidance_weight=guidance_weight,
|
@@ -1064,35 +1275,62 @@ def jam_update(
|
|
1064 |
topk=topk
|
1065 |
)
|
1066 |
|
1067 |
-
#
|
1068 |
-
wants_style_update =
|
1069 |
-
|
1070 |
-
|
1071 |
-
|
1072 |
-
|
1073 |
-
|
1074 |
-
|
1075 |
-
|
1076 |
-
|
1077 |
-
|
1078 |
-
|
1079 |
-
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
1090 |
-
|
1091 |
-
|
1092 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1093 |
|
1094 |
return {"ok": True}
|
1095 |
|
|
|
1096 |
@app.post("/jam/reseed")
|
1097 |
def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
|
1098 |
with jam_lock:
|
@@ -1217,21 +1455,6 @@ async def log_requests(request: Request, call_next):
|
|
1217 |
# ----------------------------
|
1218 |
|
1219 |
|
1220 |
-
|
1221 |
-
def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
|
1222 |
-
extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
|
1223 |
-
if not extra:
|
1224 |
-
return mrt.embed_style("warmup")
|
1225 |
-
sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
|
1226 |
-
embeds, weights = [], []
|
1227 |
-
for i, s in enumerate(extra):
|
1228 |
-
embeds.append(mrt.embed_style(s))
|
1229 |
-
weights.append(sw[i] if i < len(sw) else 1.0)
|
1230 |
-
wsum = sum(weights) or 1.0
|
1231 |
-
weights = [w/wsum for w in weights]
|
1232 |
-
import numpy as np
|
1233 |
-
return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
|
1234 |
-
|
1235 |
@app.websocket("/ws/jam")
|
1236 |
async def ws_jam(websocket: WebSocket):
|
1237 |
await websocket.accept()
|
@@ -1253,7 +1476,7 @@ async def ws_jam(websocket: WebSocket):
|
|
1253 |
# --- START ---
|
1254 |
if mtype == "start":
|
1255 |
binary_audio = bool(msg.get("binary_audio", False))
|
1256 |
-
mode = msg.get("mode", "
|
1257 |
params = msg.get("params", {}) or {}
|
1258 |
sid = msg.get("session_id")
|
1259 |
|
@@ -1332,37 +1555,75 @@ async def ws_jam(websocket: WebSocket):
|
|
1332 |
|
1333 |
else:
|
1334 |
# mode == "rt" (Colab-style, no loop context)
|
1335 |
-
# seed a fresh state with a silent context like warmup
|
1336 |
mrt = get_mrt()
|
1337 |
state = mrt.init_state()
|
1338 |
-
|
|
|
|
|
1339 |
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
1340 |
sr = int(mrt.sample_rate)
|
1341 |
samples = int(max(1, round(ctx_seconds * sr)))
|
1342 |
-
silent = au.Waveform(np.zeros((samples,2), np.float32), sr)
|
1343 |
tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
|
1344 |
state.context_tokens = tokens
|
1345 |
|
1346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1347 |
websocket._state = state
|
1348 |
-
websocket._style =
|
1349 |
-
|
1350 |
-
|
1351 |
-
websocket.
|
1352 |
-
websocket.
|
1353 |
-
websocket.
|
1354 |
-
websocket.
|
1355 |
-
websocket.
|
1356 |
-
websocket.
|
1357 |
-
|
1358 |
-
|
|
|
|
|
|
|
|
|
|
|
1359 |
async def _rt_loop():
|
1360 |
try:
|
1361 |
mrt = websocket._mrt
|
1362 |
chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
|
1363 |
target_next = time.perf_counter()
|
1364 |
while websocket._rt_running:
|
1365 |
-
# read knobs (already set by update)
|
1366 |
mrt.guidance_weight = websocket._rt_guid
|
1367 |
mrt.temperature = websocket._rt_temp
|
1368 |
mrt.topk = websocket._rt_topk
|
@@ -1374,37 +1635,32 @@ async def ws_jam(websocket: WebSocket):
|
|
1374 |
buf = io.BytesIO()
|
1375 |
sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
|
1376 |
|
1377 |
-
# send bytes / json best-effort
|
1378 |
ok = True
|
1379 |
if binary_audio:
|
1380 |
try:
|
1381 |
await websocket.send_bytes(buf.getvalue())
|
1382 |
-
ok = await send_json({"type":"chunk_meta","metadata":{"sample_rate":mrt.sample_rate}})
|
1383 |
except Exception:
|
1384 |
ok = False
|
1385 |
else:
|
1386 |
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
1387 |
-
ok = await send_json({"type":"chunk","audio_base64":b64,
|
1388 |
-
"metadata":{"sample_rate":mrt.sample_rate}})
|
1389 |
|
1390 |
if not ok:
|
1391 |
-
# client went away — exit cleanly
|
1392 |
break
|
1393 |
|
1394 |
-
# pacing (use captured flag from start)
|
1395 |
if getattr(websocket, "_pace", "asap") == "realtime":
|
1396 |
t1 = time.perf_counter()
|
1397 |
target_next += chunk_secs
|
1398 |
sleep_s = max(0.0, target_next - t1 - 0.02)
|
1399 |
if sleep_s > 0:
|
1400 |
await asyncio.sleep(sleep_s)
|
1401 |
-
|
1402 |
except asyncio.CancelledError:
|
1403 |
-
# normal on stop/close — just exit
|
1404 |
pass
|
1405 |
except Exception:
|
1406 |
-
# don't try to send an error; socket may be closed
|
1407 |
pass
|
|
|
1408 |
websocket._rt_task = asyncio.create_task(_rt_loop())
|
1409 |
continue # skip the “bar-mode started” message below
|
1410 |
|
@@ -1450,13 +1706,37 @@ async def ws_jam(websocket: WebSocket):
|
|
1450 |
websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
|
1451 |
websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
|
1452 |
|
1453 |
-
|
1454 |
-
|
1455 |
-
|
1456 |
-
|
1457 |
-
|
1458 |
-
|
1459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1460 |
|
1461 |
elif mtype == "consume" and mode == "bar":
|
1462 |
with jam_lock:
|
|
|
32 |
|
33 |
from magenta_rt import system, audio as au
|
34 |
import numpy as np
|
35 |
+
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
|
36 |
import tempfile, io, base64, math, threading
|
37 |
from fastapi.middleware.cors import CORSMiddleware
|
38 |
from contextlib import contextmanager
|
|
|
70 |
from pathlib import Path
|
71 |
from huggingface_hub import snapshot_download
|
72 |
|
73 |
+
# ---- Finetune assets (mean & centroids) --------------------------------------
# Module-level cache for the optional style-steering assets. The three cache
# slots below start as None and are assigned by _load_finetune_assets_from_hf()
# once a load succeeds.

# Default Hugging Face repo to pull assets from; the MRT_ASSETS_REPO environment
# variable overrides the built-in "thepatch/magenta-ft".
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
# Repo id of the most recent successful load (None until one succeeds).
_ASSETS_REPO_ID: str | None = None
# Contents of mean_style_embed.npy; the loader enforces 1-D and casts to float32.
_MEAN_EMBED: np.ndarray | None = None     # shape (D,)   dtype float32
# Contents of cluster_centroids.npy; the loader enforces 2-D and casts to float32.
_CENTROIDS: np.ndarray | None = None      # shape (K, D) dtype float32
|
78 |
+
|
79 |
+
def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
    """
    Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.

    Either file is optional, but at least one must be retrievable. On success
    all three module globals are overwritten together: an asset that could not
    be fetched is reset to None so that assets from two different repos are
    never mixed. On any failure the globals are left untouched.

    Args:
        repo_id: Hugging Face model repo id; falls back to
            _FINETUNE_REPO_DEFAULT when None or empty.

    Returns:
        (ok, message): ok is True only when the globals were updated; message
        is "ok" on success, otherwise a human-readable reason.

    Safe to call multiple times; never raises (unexpected errors are logged
    and reported through the returned message).
    """
    global _ASSETS_REPO_ID, _MEAN_EMBED, _CENTROIDS
    repo_id = repo_id or _FINETUNE_REPO_DEFAULT
    try:
        from huggingface_hub import hf_hub_download

        # Fetch each asset independently: a repo may ship only one of them.
        paths: dict[str, str | None] = {}
        for filename in ("mean_style_embed.npy", "cluster_centroids.npy"):
            try:
                paths[filename] = hf_hub_download(repo_id, filename=filename, repo_type="model")
            except Exception as err:
                # Still best-effort, but record the reason so that a network or
                # auth problem is distinguishable from "file not in repo".
                logging.info("Could not fetch %s from %s: %s", filename, repo_id, err)
                paths[filename] = None
        mean_path = paths["mean_style_embed.npy"]
        cent_path = paths["cluster_centroids.npy"]

        if mean_path is None and cent_path is None:
            return False, f"No finetune asset files found in repo {repo_id}"

        if mean_path is not None:
            m = np.load(mean_path)
            if m.ndim != 1:
                return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
        else:
            m = None

        if cent_path is not None:
            c = np.load(cent_path)
            if c.ndim != 2:
                return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"
        else:
            c = None

        # Optional: shape check vs model embedding dim once model is alive
        try:
            d = int(get_mrt().style_model.config.embedding_dim)
        except Exception:
            # Model (or its style config) not available; trust the files and
            # rely on runtime checks later.
            d = None
        if d is not None:
            if m is not None and m.shape[0] != d:
                return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
            if c is not None and c.shape[1] != d:
                return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"

        # All validation passed: publish the new assets.
        _MEAN_EMBED = m.astype(np.float32, copy=False) if m is not None else None
        _CENTROIDS = c.astype(np.float32, copy=False) if c is not None else None
        _ASSETS_REPO_ID = repo_id
        logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
                     repo_id,
                     "yes" if _MEAN_EMBED is not None else "no",
                     f"{_CENTROIDS.shape[0]}x{_CENTROIDS.shape[1]}" if _CENTROIDS is not None else "no")
        return True, "ok"
    except Exception as e:
        logging.exception("Failed to load finetune assets: %s", e)
        return False, str(e)
|
138 |
+
|
139 |
+
def _ensure_assets_loaded():
    """Best-effort lazy load of the finetune assets.

    Does nothing when either the mean embedding or the centroids are already
    cached; otherwise triggers a load from the last-used repo, falling back to
    the default repo. The loader's result is deliberately ignored.
    """
    if _MEAN_EMBED is not None or _CENTROIDS is not None:
        return
    _load_finetune_assets_from_hf(_ASSETS_REPO_ID or _FINETUNE_REPO_DEFAULT)
|
143 |
+
# ------------------------------------------------------------------------------
|
144 |
+
|
145 |
def _resolve_checkpoint_dir() -> str | None:
|
146 |
repo_id = os.getenv("MRT_CKPT_REPO")
|
147 |
if not repo_id:
|
|
|
612 |
|
613 |
return out, loud_stats
|
614 |
|
615 |
+
# untested.
|
616 |
+
# Not sure how it will retain the input BPM. We may want to use a metronome instead of silence; I think Google might do that.
|
617 |
+
# does a generation with silent context rather than a combined loop
|
618 |
def generate_style_only_with_mrt(
|
619 |
mrt,
|
620 |
bpm: float,
|
|
|
685 |
|
686 |
return out, None # loudness stats not applicable (no reference)
|
687 |
|
688 |
+
def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
    """Blend comma-separated text styles into one float32 style embedding.

    styles_str holds comma-separated style prompts and weights_str the
    matching comma-separated weights. Missing weights default to 1.0 and the
    weights are normalised by their sum (or left as-is when the sum is 0).
    With no usable style names the embedding of "warmup" is returned
    unchanged. A non-numeric weight raises ValueError.
    """
    import numpy as np

    names = [piece.strip() for piece in (styles_str or "").split(",")]
    names = [name for name in names if name]
    if not names:
        return mrt.embed_style("warmup")

    # Parse weights before embedding so a bad weight fails fast.
    parsed = [float(token) for token in (weights_str or "").split(",") if token.strip()]
    raw_weights = parsed[:len(names)]
    raw_weights += [1.0] * (len(names) - len(raw_weights))

    vectors = [mrt.embed_style(name) for name in names]

    total = sum(raw_weights) or 1.0
    scaled = []
    for weight, vector in zip(raw_weights, vectors):
        share = weight / total
        scaled.append(share * vector)
    blended = np.sum(scaled, axis=0)
    return blended.astype(np.float32)
|
701 |
+
|
702 |
+
def build_style_vector(
|
703 |
+
mrt,
|
704 |
+
*,
|
705 |
+
text_styles: list[str] | None = None,
|
706 |
+
text_weights: list[float] | None = None,
|
707 |
+
loop_embed: np.ndarray | None = None,
|
708 |
+
loop_weight: float | None = None,
|
709 |
+
mean_weight: float | None = None,
|
710 |
+
centroid_weights: list[float] | None = None,
|
711 |
+
) -> np.ndarray:
|
712 |
+
"""
|
713 |
+
Returns a single style embedding combining:
|
714 |
+
- loop embedding (optional)
|
715 |
+
- one or more text style embeddings (optional)
|
716 |
+
- mean finetune embedding (optional)
|
717 |
+
- centroid embeddings (optional)
|
718 |
+
All weights are normalized so they sum to 1 if > 0.
|
719 |
+
"""
|
720 |
+
comps: list[np.ndarray] = []
|
721 |
+
weights: list[float] = []
|
722 |
+
|
723 |
+
# loop component
|
724 |
+
if loop_embed is not None and (loop_weight or 0) > 0:
|
725 |
+
comps.append(loop_embed.astype(np.float32, copy=False))
|
726 |
+
weights.append(float(loop_weight))
|
727 |
+
|
728 |
+
# text components
|
729 |
+
if text_styles:
|
730 |
+
for i, s in enumerate(text_styles):
|
731 |
+
s = s.strip()
|
732 |
+
if not s:
|
733 |
+
continue
|
734 |
+
w = 1.0
|
735 |
+
if text_weights and i < len(text_weights):
|
736 |
+
try: w = float(text_weights[i])
|
737 |
+
except: w = 1.0
|
738 |
+
if w <= 0:
|
739 |
+
continue
|
740 |
+
e = mrt.embed_style(s)
|
741 |
+
comps.append(e.astype(np.float32, copy=False))
|
742 |
+
weights.append(w)
|
743 |
+
|
744 |
+
# mean finetune
|
745 |
+
if mean_weight and (_MEAN_EMBED is not None) and mean_weight > 0:
|
746 |
+
comps.append(_MEAN_EMBED)
|
747 |
+
weights.append(float(mean_weight))
|
748 |
+
|
749 |
+
# centroid components
|
750 |
+
if centroid_weights and _CENTROIDS is not None:
|
751 |
+
K = _CENTROIDS.shape[0]
|
752 |
+
for k, w in enumerate(centroid_weights[:K]):
|
753 |
+
try: w = float(w)
|
754 |
+
except: w = 0.0
|
755 |
+
if w <= 0:
|
756 |
+
continue
|
757 |
+
comps.append(_CENTROIDS[k])
|
758 |
+
weights.append(w)
|
759 |
+
|
760 |
+
if not comps:
|
761 |
+
# fallback: neutral style if nothing provided
|
762 |
+
return mrt.embed_style("")
|
763 |
+
|
764 |
+
wsum = sum(weights)
|
765 |
+
if wsum <= 0:
|
766 |
+
return mrt.embed_style("")
|
767 |
+
weights = [w/wsum for w in weights]
|
768 |
+
|
769 |
+
# weighted sum
|
770 |
+
out = np.zeros_like(comps[0], dtype=np.float32)
|
771 |
+
for w, e in zip(weights, comps):
|
772 |
+
out += w * e.astype(np.float32, copy=False)
|
773 |
+
return out
|
774 |
|
775 |
|
776 |
|
|
|
830 |
beats_per_bar = 4
|
831 |
|
832 |
# --- build a silent, stereo context of ctx_seconds ---
|
|
|
833 |
samples = int(max(1, round(ctx_seconds * sr)))
|
834 |
silent = np.zeros((samples, 2), dtype=np.float32)
|
835 |
|
|
|
905 |
# optionally pre-warm here by calling get_mrt()
|
906 |
return {"reloaded": True, "step": step}
|
907 |
|
908 |
+
@app.post("/model/assets/load")
def model_assets_load(repo_id: str = Form(None)):
    """Load (or reload) the finetune assets from ``repo_id`` and report the outcome.

    The response carries the loader's success flag and message, plus the state
    of the cached assets after the attempt.
    """
    loaded, detail = _load_finetune_assets_from_hf(repo_id)
    centroid_count = int(_CENTROIDS.shape[0]) if _CENTROIDS is not None else None
    return {
        "ok": loaded,
        "message": detail,
        "repo_id": _ASSETS_REPO_ID,
        "mean": _MEAN_EMBED is not None,
        "centroids": centroid_count,
    }
|
914 |
+
|
915 |
+
@app.get("/model/assets/status")
def model_assets_status():
    """Report which finetune assets are cached and the model's embedding size.

    ``embedding_dim`` is None whenever it cannot be read from the model.
    """
    try:
        embedding_dim = int(get_mrt().style_model.config.embedding_dim)
    except Exception:
        embedding_dim = None

    have_centroids = _CENTROIDS is not None
    return {
        "repo_id": _ASSETS_REPO_ID,
        "mean_loaded": _MEAN_EMBED is not None,
        "centroids_loaded": have_centroids,
        "centroid_count": int(_CENTROIDS.shape[0]) if have_centroids else None,
        "embedding_dim": embedding_dim,
    }
|
929 |
+
|
930 |
@app.post("/generate")
|
931 |
def generate(
|
932 |
loop_audio: UploadFile = File(...),
|
|
|
1093 |
styles: str = Form(""),
|
1094 |
style_weights: str = Form(""),
|
1095 |
loop_weight: float = Form(1.0),
|
1096 |
+
|
1097 |
+
# NEW steering params:
|
1098 |
+
mean: float = Form(0.0),
|
1099 |
+
centroid_weights: str = Form(""),
|
1100 |
+
|
1101 |
loudness_mode: str = Form("auto"),
|
1102 |
loudness_headroom_db: float = Form(1.0),
|
1103 |
guidance_weight: float = Form(1.1),
|
|
|
1105 |
topk: int = Form(40),
|
1106 |
target_sample_rate: int | None = Form(None),
|
1107 |
):
|
1108 |
+
_ensure_assets_loaded()
|
1109 |
+
|
1110 |
# enforce single active jam per GPU
|
1111 |
with jam_lock:
|
1112 |
for sid, w in list(jam_registry.items()):
|
|
|
1127 |
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
1128 |
loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
|
1129 |
|
1130 |
+
# Parse client style fields (preserves your semantics)
|
1131 |
+
text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
|
1132 |
+
try:
|
1133 |
+
tw = [float(x) for x in style_weights.split(",")] if style_weights else []
|
1134 |
+
except ValueError:
|
1135 |
+
tw = []
|
1136 |
+
try:
|
1137 |
+
cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
|
1138 |
+
except ValueError:
|
1139 |
+
cw = []
|
1140 |
+
|
1141 |
+
# Compute loop-tail embed once (same as before)
|
1142 |
+
loop_tail_embed = mrt.embed_style(loop_tail)
|
1143 |
+
|
1144 |
+
# Build final style vector:
|
1145 |
+
# - identical to your previous mix when mean==0 and cw is empty
|
1146 |
+
# - otherwise includes mean and centroid components (weights auto-normalized)
|
1147 |
+
style_vec = build_style_vector(
|
1148 |
+
mrt,
|
1149 |
+
text_styles=text_list,
|
1150 |
+
text_weights=tw,
|
1151 |
+
loop_embed=loop_tail_embed,
|
1152 |
+
loop_weight=float(loop_weight),
|
1153 |
+
mean_weight=float(mean),
|
1154 |
+
centroid_weights=cw,
|
1155 |
+
).astype(np.float32, copy=False)
|
1156 |
|
1157 |
# target SR (default input SR)
|
1158 |
inp_info = sf.info(tmp_path)
|
|
|
1241 |
jam_registry.pop(session_id, None)
|
1242 |
return {"stopped": True}
|
1243 |
|
1244 |
+
@app.post("/jam/update")
|
1245 |
def jam_update(
|
1246 |
session_id: str = Form(...),
|
1247 |
|
1248 |
+
# knobs
|
1249 |
guidance_weight: Optional[float] = Form(None),
|
1250 |
temperature: Optional[float] = Form(None),
|
1251 |
topk: Optional[int] = Form(None),
|
1252 |
|
1253 |
+
# styles
|
1254 |
styles: str = Form(""),
|
1255 |
style_weights: str = Form(""),
|
1256 |
+
loop_weight: Optional[float] = Form(None),
|
1257 |
use_current_mix_as_style: bool = Form(False),
|
1258 |
+
|
1259 |
+
# NEW steering
|
1260 |
+
mean: Optional[float] = Form(None),
|
1261 |
+
centroid_weights: str = Form(""),
|
1262 |
):
|
1263 |
+
_ensure_assets_loaded()
|
1264 |
+
|
1265 |
with jam_lock:
|
1266 |
worker = jam_registry.get(session_id)
|
1267 |
if worker is None or not worker.is_alive():
|
1268 |
raise HTTPException(status_code=404, detail="Session not found")
|
1269 |
|
1270 |
+
# 1) fast knob updates
|
1271 |
if any(v is not None for v in (guidance_weight, temperature, topk)):
|
1272 |
worker.update_knobs(
|
1273 |
guidance_weight=guidance_weight,
|
|
|
1275 |
topk=topk
|
1276 |
)
|
1277 |
|
1278 |
+
# 2) rebuild style only if asked
|
1279 |
+
wants_style_update = (
|
1280 |
+
use_current_mix_as_style
|
1281 |
+
or (styles.strip() != "")
|
1282 |
+
or (mean is not None)
|
1283 |
+
or (centroid_weights.strip() != "")
|
1284 |
+
)
|
1285 |
+
if not wants_style_update:
|
1286 |
+
return {"ok": True}
|
1287 |
+
|
1288 |
+
# --- parse inputs (robust) ---
|
1289 |
+
text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
|
1290 |
+
try:
|
1291 |
+
tw = [float(x) for x in style_weights.split(",")] if style_weights else []
|
1292 |
+
except ValueError:
|
1293 |
+
tw = []
|
1294 |
+
try:
|
1295 |
+
cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
|
1296 |
+
except ValueError:
|
1297 |
+
cw = []
|
1298 |
+
|
1299 |
+
# Clamp centroid weights to available centroids (if loaded)
|
1300 |
+
max_c = 0 if _CENTROIDS is None else int(_CENTROIDS.shape[0])
|
1301 |
+
if max_c and len(cw) > max_c:
|
1302 |
+
cw = cw[:max_c]
|
1303 |
+
|
1304 |
+
# Snapshot minimal state under lock
|
1305 |
+
with worker._lock:
|
1306 |
+
combined_loop = worker.params.combined_loop if use_current_mix_as_style else None
|
1307 |
+
lw = None
|
1308 |
+
if use_current_mix_as_style:
|
1309 |
+
lw = 1.0 if (loop_weight is None) else float(loop_weight)
|
1310 |
+
mrt = worker.mrt
|
1311 |
+
|
1312 |
+
# Heavy work OUTSIDE the lock
|
1313 |
+
loop_embed = None
|
1314 |
+
if combined_loop is not None:
|
1315 |
+
loop_embed = mrt.embed_style(combined_loop)
|
1316 |
+
|
1317 |
+
style_vec = build_style_vector(
|
1318 |
+
mrt,
|
1319 |
+
text_styles=text_list,
|
1320 |
+
text_weights=tw,
|
1321 |
+
loop_embed=loop_embed, # None => ignored by builder
|
1322 |
+
loop_weight=lw, # None => ignored by builder
|
1323 |
+
mean_weight=(None if mean is None else float(mean)),
|
1324 |
+
centroid_weights=cw, # [] => ignored by builder
|
1325 |
+
).astype(np.float32, copy=False)
|
1326 |
+
|
1327 |
+
# Swap atomically
|
1328 |
+
with worker._lock:
|
1329 |
+
worker.params.style_vec = style_vec
|
1330 |
|
1331 |
return {"ok": True}
|
1332 |
|
1333 |
+
|
1334 |
@app.post("/jam/reseed")
|
1335 |
def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
|
1336 |
with jam_lock:
|
|
|
1455 |
# ----------------------------
|
1456 |
|
1457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1458 |
@app.websocket("/ws/jam")
|
1459 |
async def ws_jam(websocket: WebSocket):
|
1460 |
await websocket.accept()
|
|
|
1476 |
# --- START ---
|
1477 |
if mtype == "start":
|
1478 |
binary_audio = bool(msg.get("binary_audio", False))
|
1479 |
+
mode = msg.get("mode", "rt")
|
1480 |
params = msg.get("params", {}) or {}
|
1481 |
sid = msg.get("session_id")
|
1482 |
|
|
|
1555 |
|
1556 |
else:
|
1557 |
# mode == "rt" (Colab-style, no loop context)
|
|
|
1558 |
mrt = get_mrt()
|
1559 |
state = mrt.init_state()
|
1560 |
+
|
1561 |
+
# Build silent context (10s) tokens
|
1562 |
+
codec_fps = float(mrt.codec.frame_rate)
|
1563 |
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
1564 |
sr = int(mrt.sample_rate)
|
1565 |
samples = int(max(1, round(ctx_seconds * sr)))
|
1566 |
+
silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
|
1567 |
tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
|
1568 |
state.context_tokens = tokens
|
1569 |
|
1570 |
+
# Parse params (including steering)
|
1571 |
+
_ensure_assets_loaded()
|
1572 |
+
styles_str = params.get("styles", "warmup") or ""
|
1573 |
+
style_weights_str = params.get("style_weights", "") or ""
|
1574 |
+
mean_w = float(params.get("mean", 0.0) or 0.0)
|
1575 |
+
cw_str = str(params.get("centroid_weights", "") or "")
|
1576 |
+
|
1577 |
+
text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
|
1578 |
+
try:
|
1579 |
+
text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
|
1580 |
+
except ValueError:
|
1581 |
+
text_w = []
|
1582 |
+
try:
|
1583 |
+
cw = [float(x) for x in cw_str.split(",") if x.strip() != ""]
|
1584 |
+
except ValueError:
|
1585 |
+
cw = []
|
1586 |
+
|
1587 |
+
# Clamp centroid weights to available centroids
|
1588 |
+
if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
|
1589 |
+
cw = cw[: int(_CENTROIDS.shape[0])]
|
1590 |
+
|
1591 |
+
# Build initial style vector (no loop_embed in rt mode)
|
1592 |
+
style_vec = build_style_vector(
|
1593 |
+
mrt,
|
1594 |
+
text_styles=text_list,
|
1595 |
+
text_weights=text_w,
|
1596 |
+
loop_embed=None,
|
1597 |
+
loop_weight=None,
|
1598 |
+
mean_weight=mean_w,
|
1599 |
+
centroid_weights=cw,
|
1600 |
+
)
|
1601 |
+
|
1602 |
+
# Stash rt session fields
|
1603 |
+
websocket._mrt = mrt
|
1604 |
websocket._state = state
|
1605 |
+
websocket._style = style_vec
|
1606 |
+
|
1607 |
+
websocket._rt_mean = mean_w
|
1608 |
+
websocket._rt_centroid_weights = cw
|
1609 |
+
websocket._rt_running = True
|
1610 |
+
websocket._rt_sr = sr
|
1611 |
+
websocket._rt_topk = int(params.get("topk", 40))
|
1612 |
+
websocket._rt_temp = float(params.get("temperature", 1.1))
|
1613 |
+
websocket._rt_guid = float(params.get("guidance_weight", 1.1))
|
1614 |
+
websocket._pace = params.get("pace", "asap") # "realtime" | "asap"
|
1615 |
+
|
1616 |
+
# (Optional) report whether steering assets were loaded
|
1617 |
+
assets_ok = (_MEAN_EMBED is not None) or (_CENTROIDS is not None)
|
1618 |
+
await send_json({"type": "started", "mode": "rt", "steering_assets": "loaded" if assets_ok else "none"})
|
1619 |
+
|
1620 |
+
# kick off the ~2s streaming loop
|
1621 |
async def _rt_loop():
|
1622 |
try:
|
1623 |
mrt = websocket._mrt
|
1624 |
chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
|
1625 |
target_next = time.perf_counter()
|
1626 |
while websocket._rt_running:
|
|
|
1627 |
mrt.guidance_weight = websocket._rt_guid
|
1628 |
mrt.temperature = websocket._rt_temp
|
1629 |
mrt.topk = websocket._rt_topk
|
|
|
1635 |
buf = io.BytesIO()
|
1636 |
sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
|
1637 |
|
|
|
1638 |
ok = True
|
1639 |
if binary_audio:
|
1640 |
try:
|
1641 |
await websocket.send_bytes(buf.getvalue())
|
1642 |
+
ok = await send_json({"type": "chunk_meta", "metadata": {"sample_rate": mrt.sample_rate}})
|
1643 |
except Exception:
|
1644 |
ok = False
|
1645 |
else:
|
1646 |
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
1647 |
+
ok = await send_json({"type": "chunk", "audio_base64": b64,
|
1648 |
+
"metadata": {"sample_rate": mrt.sample_rate}})
|
1649 |
|
1650 |
if not ok:
|
|
|
1651 |
break
|
1652 |
|
|
|
1653 |
if getattr(websocket, "_pace", "asap") == "realtime":
|
1654 |
t1 = time.perf_counter()
|
1655 |
target_next += chunk_secs
|
1656 |
sleep_s = max(0.0, target_next - t1 - 0.02)
|
1657 |
if sleep_s > 0:
|
1658 |
await asyncio.sleep(sleep_s)
|
|
|
1659 |
except asyncio.CancelledError:
|
|
|
1660 |
pass
|
1661 |
except Exception:
|
|
|
1662 |
pass
|
1663 |
+
|
1664 |
websocket._rt_task = asyncio.create_task(_rt_loop())
|
1665 |
continue # skip the “bar-mode started” message below
|
1666 |
|
|
|
1706 |
websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
|
1707 |
websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
|
1708 |
|
1709 |
+
# NEW steering fields
|
1710 |
+
if "mean" in msg and msg["mean"] is not None:
|
1711 |
+
try: websocket._rt_mean = float(msg["mean"])
|
1712 |
+
except: websocket._rt_mean = 0.0
|
1713 |
+
|
1714 |
+
if "centroid_weights" in msg:
|
1715 |
+
cw = [w.strip() for w in str(msg["centroid_weights"]).split(",") if w.strip() != ""]
|
1716 |
+
try:
|
1717 |
+
websocket._rt_centroid_weights = [float(x) for x in cw]
|
1718 |
+
except:
|
1719 |
+
websocket._rt_centroid_weights = []
|
1720 |
+
|
1721 |
+
# styles / text weights (optional, comma-separated)
|
1722 |
+
styles_str = msg.get("styles", None)
|
1723 |
+
style_weights_str = msg.get("style_weights", "")
|
1724 |
+
|
1725 |
+
text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
|
1726 |
+
text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
|
1727 |
+
|
1728 |
+
_ensure_assets_loaded()
|
1729 |
+
# build final style vec (no loop embedding in rt-mode)
|
1730 |
+
websocket._style = build_style_vector(
|
1731 |
+
websocket._mrt,
|
1732 |
+
text_styles=text_list,
|
1733 |
+
text_weights=text_w,
|
1734 |
+
loop_embed=None,
|
1735 |
+
loop_weight=None,
|
1736 |
+
mean_weight=float(websocket._rt_mean),
|
1737 |
+
centroid_weights=websocket._rt_centroid_weights,
|
1738 |
+
)
|
1739 |
+
await send_json({"type":"status","updated":"rt-knobs+style"})
|
1740 |
|
1741 |
elif mtype == "consume" and mode == "bar":
|
1742 |
with jam_lock:
|