thecollabagepatch committed on
Commit
a8de318
·
1 Parent(s): 80346b6

attempting to add centroids/mean

Browse files
Files changed (1) hide show
  1. app.py +371 -91
app.py CHANGED
@@ -32,7 +32,7 @@ except Exception:
32
 
33
  from magenta_rt import system, audio as au
34
  import numpy as np
35
- from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
36
  import tempfile, io, base64, math, threading
37
  from fastapi.middleware.cors import CORSMiddleware
38
  from contextlib import contextmanager
@@ -70,6 +70,78 @@ import re, tarfile
70
  from pathlib import Path
71
  from huggingface_hub import snapshot_download
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def _resolve_checkpoint_dir() -> str | None:
74
  repo_id = os.getenv("MRT_CKPT_REPO")
75
  if not repo_id:
@@ -540,6 +612,9 @@ def generate_loop_continuation_with_mrt(
540
 
541
  return out, loud_stats
542
 
 
 
 
543
  def generate_style_only_with_mrt(
544
  mrt,
545
  bpm: float,
@@ -610,6 +685,92 @@ def generate_style_only_with_mrt(
610
 
611
  return out, None # loudness stats not applicable (no reference)
612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
 
614
 
615
 
@@ -669,7 +830,6 @@ def _mrt_warmup():
669
  beats_per_bar = 4
670
 
671
  # --- build a silent, stereo context of ctx_seconds ---
672
- import numpy as np, soundfile as sf
673
  samples = int(max(1, round(ctx_seconds * sr)))
674
  silent = np.zeros((samples, 2), dtype=np.float32)
675
 
@@ -745,6 +905,28 @@ def model_swap(step: int = Form(...)):
745
  # optionally pre-warm here by calling get_mrt()
746
  return {"reloaded": True, "step": step}
747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
748
  @app.post("/generate")
749
  def generate(
750
  loop_audio: UploadFile = File(...),
@@ -911,6 +1093,11 @@ def jam_start(
911
  styles: str = Form(""),
912
  style_weights: str = Form(""),
913
  loop_weight: float = Form(1.0),
 
 
 
 
 
914
  loudness_mode: str = Form("auto"),
915
  loudness_headroom_db: float = Form(1.0),
916
  guidance_weight: float = Form(1.1),
@@ -918,6 +1105,8 @@ def jam_start(
918
  topk: int = Form(40),
919
  target_sample_rate: int | None = Form(None),
920
  ):
 
 
921
  # enforce single active jam per GPU
922
  with jam_lock:
923
  for sid, w in list(jam_registry.items()):
@@ -938,16 +1127,32 @@ def jam_start(
938
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
939
  loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
940
 
941
- # style vec = normalized mix of loop_tail + extra styles
942
- embeds, weights = [mrt.embed_style(loop_tail)], [float(loop_weight)]
943
- extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
944
- sw = [float(x) for x in style_weights.split(",")] if style_weights else []
945
- for i, s in enumerate(extra):
946
- embeds.append(mrt.embed_style(s.strip()))
947
- weights.append(sw[i] if i < len(sw) else 1.0)
948
- wsum = sum(weights) or 1.0
949
- weights = [w / wsum for w in weights]
950
- style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(embeds[0].dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
 
952
  # target SR (default input SR)
953
  inp_info = sf.info(tmp_path)
@@ -1036,27 +1241,33 @@ def jam_stop(session_id: str = Body(..., embed=True)):
1036
  jam_registry.pop(session_id, None)
1037
  return {"stopped": True}
1038
 
1039
- @app.post("/jam/update") # consolidated
1040
  def jam_update(
1041
  session_id: str = Form(...),
1042
 
1043
- # knobs (all optional)
1044
  guidance_weight: Optional[float] = Form(None),
1045
  temperature: Optional[float] = Form(None),
1046
  topk: Optional[int] = Form(None),
1047
 
1048
- # styles (all optional)
1049
  styles: str = Form(""),
1050
  style_weights: str = Form(""),
1051
- loop_weight: Optional[float] = Form(None), # None means "don’t change"
1052
  use_current_mix_as_style: bool = Form(False),
 
 
 
 
1053
  ):
 
 
1054
  with jam_lock:
1055
  worker = jam_registry.get(session_id)
1056
  if worker is None or not worker.is_alive():
1057
  raise HTTPException(status_code=404, detail="Session not found")
1058
 
1059
- # --- 1) Apply knob updates (atomic under lock)
1060
  if any(v is not None for v in (guidance_weight, temperature, topk)):
1061
  worker.update_knobs(
1062
  guidance_weight=guidance_weight,
@@ -1064,35 +1275,62 @@ def jam_update(
1064
  topk=topk
1065
  )
1066
 
1067
- # --- 2) Apply style updates only if requested
1068
- wants_style_update = use_current_mix_as_style or (styles.strip() != "")
1069
- if wants_style_update:
1070
- embeds, weights = [], []
1071
-
1072
- # optional: include current mix as a style component
1073
- if use_current_mix_as_style and worker.params.combined_loop is not None:
1074
- lw = 1.0 if loop_weight is None else float(loop_weight)
1075
- embeds.append(worker.mrt.embed_style(worker.params.combined_loop))
1076
- weights.append(lw)
1077
-
1078
- # extra text styles
1079
- extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
1080
- sw = [float(x) for x in style_weights.split(",")] if style_weights else []
1081
- for i, s in enumerate(extra):
1082
- embeds.append(worker.mrt.embed_style(s.strip()))
1083
- weights.append(sw[i] if i < len(sw) else 1.0)
1084
-
1085
- if embeds: # only swap if we actually built something
1086
- wsum = sum(weights) or 1.0
1087
- weights = [w / wsum for w in weights]
1088
- style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
1089
-
1090
- # install atomically
1091
- with worker._lock:
1092
- worker.params.style_vec = style_vec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1093
 
1094
  return {"ok": True}
1095
 
 
1096
  @app.post("/jam/reseed")
1097
  def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
1098
  with jam_lock:
@@ -1217,21 +1455,6 @@ async def log_requests(request: Request, call_next):
1217
  # ----------------------------
1218
 
1219
 
1220
-
1221
- def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
1222
- extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
1223
- if not extra:
1224
- return mrt.embed_style("warmup")
1225
- sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
1226
- embeds, weights = [], []
1227
- for i, s in enumerate(extra):
1228
- embeds.append(mrt.embed_style(s))
1229
- weights.append(sw[i] if i < len(sw) else 1.0)
1230
- wsum = sum(weights) or 1.0
1231
- weights = [w/wsum for w in weights]
1232
- import numpy as np
1233
- return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
1234
-
1235
  @app.websocket("/ws/jam")
1236
  async def ws_jam(websocket: WebSocket):
1237
  await websocket.accept()
@@ -1253,7 +1476,7 @@ async def ws_jam(websocket: WebSocket):
1253
  # --- START ---
1254
  if mtype == "start":
1255
  binary_audio = bool(msg.get("binary_audio", False))
1256
- mode = msg.get("mode", "bar")
1257
  params = msg.get("params", {}) or {}
1258
  sid = msg.get("session_id")
1259
 
@@ -1332,37 +1555,75 @@ async def ws_jam(websocket: WebSocket):
1332
 
1333
  else:
1334
  # mode == "rt" (Colab-style, no loop context)
1335
- # seed a fresh state with a silent context like warmup
1336
  mrt = get_mrt()
1337
  state = mrt.init_state()
1338
- codec_fps = float(mrt.codec.frame_rate)
 
 
1339
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1340
  sr = int(mrt.sample_rate)
1341
  samples = int(max(1, round(ctx_seconds * sr)))
1342
- silent = au.Waveform(np.zeros((samples,2), np.float32), sr)
1343
  tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1344
  state.context_tokens = tokens
1345
 
1346
- websocket._mrt = mrt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1347
  websocket._state = state
1348
- websocket._style = _combine_styles(mrt,
1349
- params.get("styles","warmup"),
1350
- params.get("style_weights",""))
1351
- websocket._rt_running = True
1352
- websocket._rt_sr = sr
1353
- websocket._rt_topk = int(params.get("topk", 40))
1354
- websocket._rt_temp = float(params.get("temperature", 1.1))
1355
- websocket._rt_guid = float(params.get("guidance_weight", 1.1))
1356
- websocket._pace = params.get("pace", "asap") # "realtime" | "asap"
1357
- await send_json({"type":"started","mode":"rt"})
1358
- # kick off a background task to stream ~2s chunks
 
 
 
 
 
1359
  async def _rt_loop():
1360
  try:
1361
  mrt = websocket._mrt
1362
  chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1363
  target_next = time.perf_counter()
1364
  while websocket._rt_running:
1365
- # read knobs (already set by update)
1366
  mrt.guidance_weight = websocket._rt_guid
1367
  mrt.temperature = websocket._rt_temp
1368
  mrt.topk = websocket._rt_topk
@@ -1374,37 +1635,32 @@ async def ws_jam(websocket: WebSocket):
1374
  buf = io.BytesIO()
1375
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1376
 
1377
- # send bytes / json best-effort
1378
  ok = True
1379
  if binary_audio:
1380
  try:
1381
  await websocket.send_bytes(buf.getvalue())
1382
- ok = await send_json({"type":"chunk_meta","metadata":{"sample_rate":mrt.sample_rate}})
1383
  except Exception:
1384
  ok = False
1385
  else:
1386
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1387
- ok = await send_json({"type":"chunk","audio_base64":b64,
1388
- "metadata":{"sample_rate":mrt.sample_rate}})
1389
 
1390
  if not ok:
1391
- # client went away — exit cleanly
1392
  break
1393
 
1394
- # pacing (use captured flag from start)
1395
  if getattr(websocket, "_pace", "asap") == "realtime":
1396
  t1 = time.perf_counter()
1397
  target_next += chunk_secs
1398
  sleep_s = max(0.0, target_next - t1 - 0.02)
1399
  if sleep_s > 0:
1400
  await asyncio.sleep(sleep_s)
1401
-
1402
  except asyncio.CancelledError:
1403
- # normal on stop/close — just exit
1404
  pass
1405
  except Exception:
1406
- # don't try to send an error; socket may be closed
1407
  pass
 
1408
  websocket._rt_task = asyncio.create_task(_rt_loop())
1409
  continue # skip the “bar-mode started” message below
1410
 
@@ -1450,13 +1706,37 @@ async def ws_jam(websocket: WebSocket):
1450
  websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1451
  websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1452
 
1453
- if ("styles" in msg) or ("style_weights" in msg):
1454
- websocket._style = _combine_styles(
1455
- websocket._mrt,
1456
- msg.get("styles", ""),
1457
- msg.get("style_weights", "")
1458
- )
1459
- await send_json({"type":"status","updated":"rt-knobs"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1460
 
1461
  elif mtype == "consume" and mode == "bar":
1462
  with jam_lock:
 
32
 
33
  from magenta_rt import system, audio as au
34
  import numpy as np
35
+ from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
36
  import tempfile, io, base64, math, threading
37
  from fastapi.middleware.cors import CORSMiddleware
38
  from contextlib import contextmanager
 
70
  from pathlib import Path
71
  from huggingface_hub import snapshot_download
72
 
73
+ # ---- Finetune assets (mean & centroids) --------------------------------------
74
+ _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
75
+ _ASSETS_REPO_ID: str | None = None
76
+ _MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
77
+ _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
78
+
79
def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
    """
    Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.

    Either file may be missing; the call only fails outright when neither can be
    fetched. Safe to call multiple times; the module globals ``_MEAN_EMBED``,
    ``_CENTROIDS`` and ``_ASSETS_REPO_ID`` are overwritten only on success.

    Args:
        repo_id: Hugging Face repo to read from. Falls back to
            ``_FINETUNE_REPO_DEFAULT`` when None or empty.

    Returns:
        ``(ok, message)``: ``(True, "ok")`` on success, otherwise ``False``
        plus a human-readable reason. This function does not raise; unexpected
        errors are logged and reported through the return value.
    """
    global _ASSETS_REPO_ID, _MEAN_EMBED, _CENTROIDS
    repo_id = repo_id or _FINETUNE_REPO_DEFAULT
    try:
        from huggingface_hub import hf_hub_download

        def _try_download(filename: str) -> str | None:
            # Each asset is optional, so a failed download is not fatal, but it
            # is logged so that a network or auth problem is not silently hidden.
            try:
                return hf_hub_download(repo_id, filename=filename, repo_type="model")
            except Exception as exc:
                logging.info("Finetune asset %s not available from %s: %s",
                             filename, repo_id, exc)
                return None

        mean_path = _try_download("mean_style_embed.npy")
        cent_path = _try_download("cluster_centroids.npy")

        if mean_path is None and cent_path is None:
            return False, f"No finetune asset files found in repo {repo_id}"

        m = None
        if mean_path is not None:
            # allow_pickle=False: never unpickle objects from a downloaded file.
            m = np.load(mean_path, allow_pickle=False)
            if m.ndim != 1:
                return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"

        c = None
        if cent_path is not None:
            c = np.load(cent_path, allow_pickle=False)
            if c.ndim != 2:
                return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"

        # The two assets are mixed into one vector later, so they must agree
        # with each other even when the model dimension cannot be checked.
        if m is not None and c is not None and m.shape[0] != c.shape[1]:
            return False, (
                f"mean_style_embed dim {m.shape[0]} != "
                f"cluster_centroids dim {c.shape[1]}"
            )

        # Optional: shape check vs model embedding dim once model is alive
        try:
            d = int(get_mrt().style_model.config.embedding_dim)
        except Exception:
            # Model not available; trust the files and rely on runtime checks later
            d = None
        if d is not None:
            if m is not None and m.shape[0] != d:
                return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
            if c is not None and c.shape[1] != d:
                return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"

        _MEAN_EMBED = m.astype(np.float32, copy=False) if m is not None else None
        _CENTROIDS = c.astype(np.float32, copy=False) if c is not None else None
        _ASSETS_REPO_ID = repo_id
        logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
                     repo_id,
                     "yes" if _MEAN_EMBED is not None else "no",
                     f"{_CENTROIDS.shape[0]}x{_CENTROIDS.shape[1]}" if _CENTROIDS is not None else "no")
        return True, "ok"
    except Exception as e:
        logging.exception("Failed to load finetune assets: %s", e)
        return False, str(e)
138
+
139
def _ensure_assets_loaded():
    """Best-effort lazy load of the finetune assets.

    Does nothing when a mean embedding or a centroid table is already present;
    otherwise attempts a load from the last used repo, or the default repo.
    The outcome of the load attempt is ignored.
    """
    already_loaded = (_MEAN_EMBED is not None) or (_CENTROIDS is not None)
    if already_loaded:
        return
    source_repo = _ASSETS_REPO_ID or _FINETUNE_REPO_DEFAULT
    _load_finetune_assets_from_hf(source_repo)
143
+ # ------------------------------------------------------------------------------
144
+
145
  def _resolve_checkpoint_dir() -> str | None:
146
  repo_id = os.getenv("MRT_CKPT_REPO")
147
  if not repo_id:
 
612
 
613
  return out, loud_stats
614
 
615
+ # Untested.
616
+ # It is unclear how this retains the input BPM; we may want to use a metronome click instead of silence as context (Google may do that).
617
+ # Runs a generation with a silent context rather than a combined loop.
618
  def generate_style_only_with_mrt(
619
  mrt,
620
  bpm: float,
 
685
 
686
  return out, None # loudness stats not applicable (no reference)
687
 
688
def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
    """Blend comma-separated text styles into one float32 style embedding.

    Args:
        mrt: Object exposing ``embed_style(text)``.
        styles_str: Comma-separated style prompts; blank entries are dropped.
        weights_str: Comma-separated weights; blank entries are dropped and
            styles without a matching weight get 1.0.

    Returns:
        ``mrt.embed_style("warmup")`` when no style is given, otherwise the
        weight-normalized sum of the style embeddings cast to float32.
    """
    names = [part.strip() for part in (styles_str or "").split(",") if part.strip()]
    if not names:
        return mrt.embed_style("warmup")

    given = [float(tok) for tok in (weights_str or "").split(",") if tok.strip()]
    vectors = [mrt.embed_style(name) for name in names]
    raw = [given[idx] if idx < len(given) else 1.0 for idx in range(len(names))]

    # A zero total falls back to 1.0 so the division below stays defined.
    total = sum(raw) or 1.0
    scaled = [(w / total) * vec for w, vec in zip(raw, vectors)]
    return np.sum(scaled, axis=0).astype(np.float32)
701
+
702
+ def build_style_vector(
703
+ mrt,
704
+ *,
705
+ text_styles: list[str] | None = None,
706
+ text_weights: list[float] | None = None,
707
+ loop_embed: np.ndarray | None = None,
708
+ loop_weight: float | None = None,
709
+ mean_weight: float | None = None,
710
+ centroid_weights: list[float] | None = None,
711
+ ) -> np.ndarray:
712
+ """
713
+ Returns a single style embedding combining:
714
+ - loop embedding (optional)
715
+ - one or more text style embeddings (optional)
716
+ - mean finetune embedding (optional)
717
+ - centroid embeddings (optional)
718
+ All weights are normalized so they sum to 1 if > 0.
719
+ """
720
+ comps: list[np.ndarray] = []
721
+ weights: list[float] = []
722
+
723
+ # loop component
724
+ if loop_embed is not None and (loop_weight or 0) > 0:
725
+ comps.append(loop_embed.astype(np.float32, copy=False))
726
+ weights.append(float(loop_weight))
727
+
728
+ # text components
729
+ if text_styles:
730
+ for i, s in enumerate(text_styles):
731
+ s = s.strip()
732
+ if not s:
733
+ continue
734
+ w = 1.0
735
+ if text_weights and i < len(text_weights):
736
+ try: w = float(text_weights[i])
737
+ except: w = 1.0
738
+ if w <= 0:
739
+ continue
740
+ e = mrt.embed_style(s)
741
+ comps.append(e.astype(np.float32, copy=False))
742
+ weights.append(w)
743
+
744
+ # mean finetune
745
+ if mean_weight and (_MEAN_EMBED is not None) and mean_weight > 0:
746
+ comps.append(_MEAN_EMBED)
747
+ weights.append(float(mean_weight))
748
+
749
+ # centroid components
750
+ if centroid_weights and _CENTROIDS is not None:
751
+ K = _CENTROIDS.shape[0]
752
+ for k, w in enumerate(centroid_weights[:K]):
753
+ try: w = float(w)
754
+ except: w = 0.0
755
+ if w <= 0:
756
+ continue
757
+ comps.append(_CENTROIDS[k])
758
+ weights.append(w)
759
+
760
+ if not comps:
761
+ # fallback: neutral style if nothing provided
762
+ return mrt.embed_style("")
763
+
764
+ wsum = sum(weights)
765
+ if wsum <= 0:
766
+ return mrt.embed_style("")
767
+ weights = [w/wsum for w in weights]
768
+
769
+ # weighted sum
770
+ out = np.zeros_like(comps[0], dtype=np.float32)
771
+ for w, e in zip(weights, comps):
772
+ out += w * e.astype(np.float32, copy=False)
773
+ return out
774
 
775
 
776
 
 
830
  beats_per_bar = 4
831
 
832
  # --- build a silent, stereo context of ctx_seconds ---
 
833
  samples = int(max(1, round(ctx_seconds * sr)))
834
  silent = np.zeros((samples, 2), dtype=np.float32)
835
 
 
905
  # optionally pre-warm here by calling get_mrt()
906
  return {"reloaded": True, "step": step}
907
 
908
@app.post("/model/assets/load")
def model_assets_load(repo_id: str = Form(None)):
    """Load (or reload) the finetune steering assets and report the result.

    ``repo_id`` is passed straight to ``_load_finetune_assets_from_hf``. The
    response carries the loader's ok flag and message together with the
    currently loaded state: repo id, whether a mean embedding is present, and
    the centroid count (None when no centroids are loaded).
    """
    ok, msg = _load_finetune_assets_from_hf(repo_id)
    centroid_count = int(_CENTROIDS.shape[0]) if _CENTROIDS is not None else None
    payload = {
        "ok": ok,
        "message": msg,
        "repo_id": _ASSETS_REPO_ID,
        "mean": _MEAN_EMBED is not None,
        "centroids": centroid_count,
    }
    return payload
914
+
915
@app.get("/model/assets/status")
def model_assets_status():
    """Report which finetune steering assets are currently loaded.

    ``embedding_dim`` is read from ``get_mrt().style_model.config`` and is None
    when that lookup raises for any reason.
    """
    try:
        embedding_dim = int(get_mrt().style_model.config.embedding_dim)
    except Exception:
        embedding_dim = None

    has_centroids = _CENTROIDS is not None
    return {
        "repo_id": _ASSETS_REPO_ID,
        "mean_loaded": _MEAN_EMBED is not None,
        "centroids_loaded": has_centroids,
        "centroid_count": int(_CENTROIDS.shape[0]) if has_centroids else None,
        "embedding_dim": embedding_dim,
    }
929
+
930
  @app.post("/generate")
931
  def generate(
932
  loop_audio: UploadFile = File(...),
 
1093
  styles: str = Form(""),
1094
  style_weights: str = Form(""),
1095
  loop_weight: float = Form(1.0),
1096
+
1097
+ # NEW steering params:
1098
+ mean: float = Form(0.0),
1099
+ centroid_weights: str = Form(""),
1100
+
1101
  loudness_mode: str = Form("auto"),
1102
  loudness_headroom_db: float = Form(1.0),
1103
  guidance_weight: float = Form(1.1),
 
1105
  topk: int = Form(40),
1106
  target_sample_rate: int | None = Form(None),
1107
  ):
1108
+ _ensure_assets_loaded()
1109
+
1110
  # enforce single active jam per GPU
1111
  with jam_lock:
1112
  for sid, w in list(jam_registry.items()):
 
1127
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1128
  loop_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
1129
 
1130
+ # Parse client style fields (preserves your semantics)
1131
+ text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
1132
+ try:
1133
+ tw = [float(x) for x in style_weights.split(",")] if style_weights else []
1134
+ except ValueError:
1135
+ tw = []
1136
+ try:
1137
+ cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
1138
+ except ValueError:
1139
+ cw = []
1140
+
1141
+ # Compute loop-tail embed once (same as before)
1142
+ loop_tail_embed = mrt.embed_style(loop_tail)
1143
+
1144
+ # Build final style vector:
1145
+ # - identical to your previous mix when mean==0 and cw is empty
1146
+ # - otherwise includes mean and centroid components (weights auto-normalized)
1147
+ style_vec = build_style_vector(
1148
+ mrt,
1149
+ text_styles=text_list,
1150
+ text_weights=tw,
1151
+ loop_embed=loop_tail_embed,
1152
+ loop_weight=float(loop_weight),
1153
+ mean_weight=float(mean),
1154
+ centroid_weights=cw,
1155
+ ).astype(np.float32, copy=False)
1156
 
1157
  # target SR (default input SR)
1158
  inp_info = sf.info(tmp_path)
 
1241
  jam_registry.pop(session_id, None)
1242
  return {"stopped": True}
1243
 
1244
+ @app.post("/jam/update")
1245
  def jam_update(
1246
  session_id: str = Form(...),
1247
 
1248
+ # knobs
1249
  guidance_weight: Optional[float] = Form(None),
1250
  temperature: Optional[float] = Form(None),
1251
  topk: Optional[int] = Form(None),
1252
 
1253
+ # styles
1254
  styles: str = Form(""),
1255
  style_weights: str = Form(""),
1256
+ loop_weight: Optional[float] = Form(None),
1257
  use_current_mix_as_style: bool = Form(False),
1258
+
1259
+ # NEW steering
1260
+ mean: Optional[float] = Form(None),
1261
+ centroid_weights: str = Form(""),
1262
  ):
1263
+ _ensure_assets_loaded()
1264
+
1265
  with jam_lock:
1266
  worker = jam_registry.get(session_id)
1267
  if worker is None or not worker.is_alive():
1268
  raise HTTPException(status_code=404, detail="Session not found")
1269
 
1270
+ # 1) fast knob updates
1271
  if any(v is not None for v in (guidance_weight, temperature, topk)):
1272
  worker.update_knobs(
1273
  guidance_weight=guidance_weight,
 
1275
  topk=topk
1276
  )
1277
 
1278
+ # 2) rebuild style only if asked
1279
+ wants_style_update = (
1280
+ use_current_mix_as_style
1281
+ or (styles.strip() != "")
1282
+ or (mean is not None)
1283
+ or (centroid_weights.strip() != "")
1284
+ )
1285
+ if not wants_style_update:
1286
+ return {"ok": True}
1287
+
1288
+ # --- parse inputs (robust) ---
1289
+ text_list = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
1290
+ try:
1291
+ tw = [float(x) for x in style_weights.split(",")] if style_weights else []
1292
+ except ValueError:
1293
+ tw = []
1294
+ try:
1295
+ cw = [float(x) for x in centroid_weights.split(",")] if centroid_weights else []
1296
+ except ValueError:
1297
+ cw = []
1298
+
1299
+ # Clamp centroid weights to available centroids (if loaded)
1300
+ max_c = 0 if _CENTROIDS is None else int(_CENTROIDS.shape[0])
1301
+ if max_c and len(cw) > max_c:
1302
+ cw = cw[:max_c]
1303
+
1304
+ # Snapshot minimal state under lock
1305
+ with worker._lock:
1306
+ combined_loop = worker.params.combined_loop if use_current_mix_as_style else None
1307
+ lw = None
1308
+ if use_current_mix_as_style:
1309
+ lw = 1.0 if (loop_weight is None) else float(loop_weight)
1310
+ mrt = worker.mrt
1311
+
1312
+ # Heavy work OUTSIDE the lock
1313
+ loop_embed = None
1314
+ if combined_loop is not None:
1315
+ loop_embed = mrt.embed_style(combined_loop)
1316
+
1317
+ style_vec = build_style_vector(
1318
+ mrt,
1319
+ text_styles=text_list,
1320
+ text_weights=tw,
1321
+ loop_embed=loop_embed, # None => ignored by builder
1322
+ loop_weight=lw, # None => ignored by builder
1323
+ mean_weight=(None if mean is None else float(mean)),
1324
+ centroid_weights=cw, # [] => ignored by builder
1325
+ ).astype(np.float32, copy=False)
1326
+
1327
+ # Swap atomically
1328
+ with worker._lock:
1329
+ worker.params.style_vec = style_vec
1330
 
1331
  return {"ok": True}
1332
 
1333
+
1334
  @app.post("/jam/reseed")
1335
  def jam_reseed(session_id: str = Form(...), loop_audio: UploadFile = File(None)):
1336
  with jam_lock:
 
1455
  # ----------------------------
1456
 
1457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1458
  @app.websocket("/ws/jam")
1459
  async def ws_jam(websocket: WebSocket):
1460
  await websocket.accept()
 
1476
  # --- START ---
1477
  if mtype == "start":
1478
  binary_audio = bool(msg.get("binary_audio", False))
1479
+ mode = msg.get("mode", "rt")
1480
  params = msg.get("params", {}) or {}
1481
  sid = msg.get("session_id")
1482
 
 
1555
 
1556
  else:
1557
  # mode == "rt" (Colab-style, no loop context)
 
1558
  mrt = get_mrt()
1559
  state = mrt.init_state()
1560
+
1561
+ # Build silent context (10s) tokens
1562
+ codec_fps = float(mrt.codec.frame_rate)
1563
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1564
  sr = int(mrt.sample_rate)
1565
  samples = int(max(1, round(ctx_seconds * sr)))
1566
+ silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
1567
  tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
1568
  state.context_tokens = tokens
1569
 
1570
+ # Parse params (including steering)
1571
+ _ensure_assets_loaded()
1572
+ styles_str = params.get("styles", "warmup") or ""
1573
+ style_weights_str = params.get("style_weights", "") or ""
1574
+ mean_w = float(params.get("mean", 0.0) or 0.0)
1575
+ cw_str = str(params.get("centroid_weights", "") or "")
1576
+
1577
+ text_list = [s.strip() for s in styles_str.split(",") if s.strip()]
1578
+ try:
1579
+ text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1580
+ except ValueError:
1581
+ text_w = []
1582
+ try:
1583
+ cw = [float(x) for x in cw_str.split(",") if x.strip() != ""]
1584
+ except ValueError:
1585
+ cw = []
1586
+
1587
+ # Clamp centroid weights to available centroids
1588
+ if _CENTROIDS is not None and len(cw) > int(_CENTROIDS.shape[0]):
1589
+ cw = cw[: int(_CENTROIDS.shape[0])]
1590
+
1591
+ # Build initial style vector (no loop_embed in rt mode)
1592
+ style_vec = build_style_vector(
1593
+ mrt,
1594
+ text_styles=text_list,
1595
+ text_weights=text_w,
1596
+ loop_embed=None,
1597
+ loop_weight=None,
1598
+ mean_weight=mean_w,
1599
+ centroid_weights=cw,
1600
+ )
1601
+
1602
+ # Stash rt session fields
1603
+ websocket._mrt = mrt
1604
  websocket._state = state
1605
+ websocket._style = style_vec
1606
+
1607
+ websocket._rt_mean = mean_w
1608
+ websocket._rt_centroid_weights = cw
1609
+ websocket._rt_running = True
1610
+ websocket._rt_sr = sr
1611
+ websocket._rt_topk = int(params.get("topk", 40))
1612
+ websocket._rt_temp = float(params.get("temperature", 1.1))
1613
+ websocket._rt_guid = float(params.get("guidance_weight", 1.1))
1614
+ websocket._pace = params.get("pace", "asap") # "realtime" | "asap"
1615
+
1616
+ # (Optional) report whether steering assets were loaded
1617
+ assets_ok = (_MEAN_EMBED is not None) or (_CENTROIDS is not None)
1618
+ await send_json({"type": "started", "mode": "rt", "steering_assets": "loaded" if assets_ok else "none"})
1619
+
1620
+ # kick off the ~2s streaming loop
1621
  async def _rt_loop():
1622
  try:
1623
  mrt = websocket._mrt
1624
  chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1625
  target_next = time.perf_counter()
1626
  while websocket._rt_running:
 
1627
  mrt.guidance_weight = websocket._rt_guid
1628
  mrt.temperature = websocket._rt_temp
1629
  mrt.topk = websocket._rt_topk
 
1635
  buf = io.BytesIO()
1636
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1637
 
 
1638
  ok = True
1639
  if binary_audio:
1640
  try:
1641
  await websocket.send_bytes(buf.getvalue())
1642
+ ok = await send_json({"type": "chunk_meta", "metadata": {"sample_rate": mrt.sample_rate}})
1643
  except Exception:
1644
  ok = False
1645
  else:
1646
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1647
+ ok = await send_json({"type": "chunk", "audio_base64": b64,
1648
+ "metadata": {"sample_rate": mrt.sample_rate}})
1649
 
1650
  if not ok:
 
1651
  break
1652
 
 
1653
  if getattr(websocket, "_pace", "asap") == "realtime":
1654
  t1 = time.perf_counter()
1655
  target_next += chunk_secs
1656
  sleep_s = max(0.0, target_next - t1 - 0.02)
1657
  if sleep_s > 0:
1658
  await asyncio.sleep(sleep_s)
 
1659
  except asyncio.CancelledError:
 
1660
  pass
1661
  except Exception:
 
1662
  pass
1663
+
1664
  websocket._rt_task = asyncio.create_task(_rt_loop())
1665
  continue # skip the “bar-mode started” message below
1666
 
 
1706
  websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1707
  websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1708
 
1709
+ # NEW steering fields
1710
+ if "mean" in msg and msg["mean"] is not None:
1711
+ try: websocket._rt_mean = float(msg["mean"])
1712
+ except: websocket._rt_mean = 0.0
1713
+
1714
+ if "centroid_weights" in msg:
1715
+ cw = [w.strip() for w in str(msg["centroid_weights"]).split(",") if w.strip() != ""]
1716
+ try:
1717
+ websocket._rt_centroid_weights = [float(x) for x in cw]
1718
+ except:
1719
+ websocket._rt_centroid_weights = []
1720
+
1721
+ # styles / text weights (optional, comma-separated)
1722
+ styles_str = msg.get("styles", None)
1723
+ style_weights_str = msg.get("style_weights", "")
1724
+
1725
+ text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
1726
+ text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1727
+
1728
+ _ensure_assets_loaded()
1729
+ # build final style vec (no loop embedding in rt-mode)
1730
+ websocket._style = build_style_vector(
1731
+ websocket._mrt,
1732
+ text_styles=text_list,
1733
+ text_weights=text_w,
1734
+ loop_embed=None,
1735
+ loop_weight=None,
1736
+ mean_weight=float(websocket._rt_mean),
1737
+ centroid_weights=websocket._rt_centroid_weights,
1738
+ )
1739
+ await send_json({"type":"status","updated":"rt-knobs+style"})
1740
 
1741
  elif mtype == "consume" and mode == "bar":
1742
  with jam_lock: