thecollabagepatch committed on
Commit
2bd198e
·
1 Parent(s): 15a3b0b

websocket route fix

Browse files
Files changed (1) hide show
  1. app.py +264 -2
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from magenta_rt import system, audio as au
2
  import numpy as np
3
- from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request
4
  import tempfile, io, base64, math, threading
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from contextlib import contextmanager
@@ -21,6 +21,9 @@ import logging
21
  import gradio as gr
22
  from typing import Optional
23
 
 
 
 
24
  # --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
25
  def _patch_t5x_for_gpu_coords():
26
  try:
@@ -897,4 +900,263 @@ def read_root():
897
  </body>
898
  </html>
899
  """
900
- return Response(content=html_content, media_type="text/html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from magenta_rt import system, audio as au
2
  import numpy as np
3
+ from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect
4
  import tempfile, io, base64, math, threading
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from contextlib import contextmanager
 
21
  import gradio as gr
22
  from typing import Optional
23
 
24
+
25
+ import json, asyncio, base64
26
+
27
  # --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
28
  def _patch_t5x_for_gpu_coords():
29
  try:
 
900
  </body>
901
  </html>
902
  """
903
+ return Response(content=html_content, media_type="text/html")
904
+
905
@contextmanager
def _ws_temp_wav(audio_b64):
    """Yield the path of a temporary ``.wav`` file holding decoded base64 audio.

    The file is created with ``delete=False`` so it can be re-opened by path,
    and is removed when the context exits so uploads do not accumulate on disk.
    """
    import os  # function-scope import: only this helper needs it

    payload = base64.b64decode(audio_b64)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(payload)
        path = tmp.name
    try:
        yield path
    finally:
        try:
            os.unlink(path)
        except OSError:
            pass  # best-effort cleanup; the file may already be gone


def _ws_opt(mapping, key, fallback, cast):
    """Return ``cast(mapping[key])``, or ``fallback`` when the key is missing or null."""
    value = mapping.get(key)
    return fallback if value is None else cast(value)


def _ws_find_worker(session_id):
    """Look up a jam worker in ``jam_registry`` under ``jam_lock`` (may return None)."""
    with jam_lock:
        return jam_registry.get(session_id)


async def _ws_send_json(websocket, obj):
    """Serialize ``obj`` to JSON and send it as one text frame."""
    await websocket.send_text(json.dumps(obj))


async def _ws_report_error(websocket, exc):
    """Best-effort error frame; swallowed if the socket is already unusable."""
    try:
        await _ws_send_json(websocket, {"type": "error", "error": str(exc)})
    except Exception:
        pass  # nothing more can be done once sending itself fails


def _ws_build_bar_worker(loop_b64, params):
    """Build (but do not start) a ``JamWorker`` for bar-mode from a base64 loop.

    Mirrors the parameter handling of the original inline code: the loop is
    resampled to the model rate, a bar-aligned tail is used for the style
    embedding, and extra text styles are blended by normalized weights.
    """
    mrt = get_mrt()
    loudness_mode = params.get("loudness_mode", "none")
    headroom_db = _ws_opt(params, "headroom_db", 1.0, float)

    with _ws_temp_wav(loop_b64) as path:
        source_sr = int(sf.info(path).samplerate)
        loop = au.Waveform.from_file(path).resample(mrt.sample_rate).as_stereo()

    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    bpm = _ws_opt(params, "bpm", 120.0, float)
    bpb = _ws_opt(params, "beats_per_bar", 4, int)
    loop_tail = take_bar_aligned_tail(loop, bpm, bpb, ctx_seconds)

    # Style vector: loop embedding plus optional comma-separated text styles.
    embeds = [mrt.embed_style(loop_tail)]
    weights = [_ws_opt(params, "loop_weight", 1.0, float)]
    extra = [s for s in (params.get("styles") or "").split(",") if s.strip()]
    sw = [float(x) for x in (params.get("style_weights") or "").split(",") if x.strip()]
    for i, s in enumerate(extra):
        embeds.append(mrt.embed_style(s.strip()))
        weights.append(sw[i] if i < len(sw) else 1.0)
    wsum = sum(weights) or 1.0
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)

    # Output rate: client override, else the uploaded file's own rate.
    # NOTE(review): the original first defaulted this to the model rate and then
    # overwrote it with this fallback; the effective (later) behavior is kept.
    target_sr = _ws_opt(params, "target_sr", source_sr, int)

    jp = JamParams(
        bpm=bpm,
        beats_per_bar=bpb,
        bars_per_chunk=_ws_opt(params, "bars_per_chunk", 8, int),
        target_sr=target_sr,
        loudness_mode=loudness_mode,
        headroom_db=headroom_db,
        style_vec=style_vec,
        ref_loop=None if loudness_mode == "none" else loop_tail,  # no loudness match by default
        combined_loop=loop,
        guidance_weight=_ws_opt(params, "guidance_weight", 1.1, float),
        temperature=_ws_opt(params, "temperature", 1.1, float),
        topk=_ws_opt(params, "topk", 40, int),
    )
    return JamWorker(mrt, jp)


def _ws_init_rt_state(params):
    """Create the per-socket state dict for rt-mode, seeded with a silent context."""
    mrt = get_mrt()
    state = mrt.init_state()
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)
    samples = int(max(1, round(ctx_seconds * sr)))
    silent = au.Waveform(np.zeros((samples, 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    state.context_tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
    return {
        "mrt": mrt,
        "state": state,
        "style": mrt.embed_style(params.get("styles", "warmup")),
        "running": True,
        "topk": _ws_opt(params, "topk", 40, int),
        "temp": _ws_opt(params, "temperature", 1.1, float),
        "guid": _ws_opt(params, "guidance_weight", 1.1, float),
    }


def _ws_rt_generate(rt):
    """Generate one rt-mode chunk and return it as float WAV bytes.

    Blocking; intended to be run in an executor thread so the event loop
    stays free to receive control messages.
    """
    mrt = rt["mrt"]
    mrt.guidance_weight = rt["guid"]
    mrt.temperature = rt["temp"]
    mrt.topk = rt["topk"]
    wav, rt["state"] = mrt.generate_chunk(state=rt["state"], style=rt["style"])
    buf = io.BytesIO()
    sf.write(buf, wav.samples.astype(np.float32, copy=False), mrt.sample_rate,
             subtype="FLOAT", format="WAV")
    return buf.getvalue()


async def _ws_rt_stream(websocket, rt, binary_audio):
    """Stream rt-mode chunks until ``rt["running"]`` is cleared or the task is cancelled."""
    loop = asyncio.get_running_loop()
    mrt = rt["mrt"]
    try:
        while rt["running"]:
            wav_bytes = await loop.run_in_executor(None, _ws_rt_generate, rt)
            if not rt["running"]:
                break  # stopped while generating; do not send a stale chunk
            if binary_audio:
                await websocket.send_bytes(wav_bytes)
                await _ws_send_json(websocket, {"type": "chunk_meta",
                                                "metadata": {"sample_rate": mrt.sample_rate}})
            else:
                b64 = base64.b64encode(wav_bytes).decode("utf-8")
                await _ws_send_json(websocket, {"type": "chunk", "audio_base64": b64,
                                                "metadata": {"sample_rate": mrt.sample_rate}})
    except Exception as e:
        await _ws_report_error(websocket, e)


async def _ws_pump_chunks(websocket, worker, binary_audio):
    """Forward chunks produced by ``worker`` to the client while the worker is alive.

    ``worker`` is bound as an argument so later rebinding of the handler's
    local variable cannot affect a running pump.
    """
    loop = asyncio.get_running_loop()
    try:
        while worker.is_alive():
            # get_next_chunk blocks up to its timeout; keep it off the event loop.
            chunk = await loop.run_in_executor(
                None, lambda: worker.get_next_chunk(timeout=60.0))
            if chunk is None:
                continue
            if binary_audio:
                await websocket.send_bytes(base64.b64decode(chunk.audio_base64))
                await _ws_send_json(websocket, {"type": "chunk_meta", "index": chunk.index,
                                                "metadata": chunk.metadata})
            else:
                await _ws_send_json(websocket, {"type": "chunk", "index": chunk.index,
                                                "audio_base64": chunk.audio_base64,
                                                "metadata": chunk.metadata})
    except Exception as e:
        await _ws_report_error(websocket, e)


@app.websocket("/ws/jam")
async def ws_jam(websocket: WebSocket):
    """WebSocket endpoint for jam sessions.

    The client sends JSON text frames with a ``type`` field:

    * ``start``  - attach to ``session_id``, or create a bar-mode worker from
      ``loop_audio_b64``, or (any other mode) start an rt-mode stream.
    * ``update`` - forward knob changes to ``jam_update`` (bar) or to the
      local rt state.
    * ``consume`` / ``reseed`` / ``reseed_splice`` - bar-mode worker controls.
    * ``stop``   - stop streaming, reply ``stopped`` and close.
    * ``ping``   - reply ``pong``.

    Audio is sent either as binary frames followed by a ``chunk_meta`` frame
    (``binary_audio`` true) or as base64 inside ``chunk`` frames.
    """
    await websocket.accept()
    sid = None
    mode = "rt"  # or "bar"
    rt = None    # rt-mode state dict, created on an rt "start"
    tasks = []   # strong references to streaming tasks so they can be cancelled

    async def send_json(obj):
        await _ws_send_json(websocket, obj)

    def halt_streams():
        """Stop any streaming started by this socket."""
        if rt is not None:
            rt["running"] = False
        for t in tasks:
            t.cancel()
        tasks.clear()

    try:
        while True:
            raw = await websocket.receive_text()
            try:
                msg = json.loads(raw)
            except ValueError:
                await send_json({"type": "error", "error": "Invalid JSON"})
                continue
            if not isinstance(msg, dict):
                await send_json({"type": "error", "error": "Message must be a JSON object"})
                continue
            mtype = msg.get("type")

            # --- START ---
            if mtype == "start":
                # A repeated "start" replaces the previous stream instead of
                # running two streams on one socket.
                halt_streams()
                rt = None
                binary_audio = bool(msg.get("binary_audio", False))
                mode = msg.get("mode", "bar")
                params = msg.get("params", {}) or {}
                sid = msg.get("session_id")

                if sid:
                    # attach to an existing session
                    worker = _ws_find_worker(sid)
                    if worker is None or not worker.is_alive():
                        sid = None
                        await send_json({"type": "error", "error": "Session not found"})
                        continue
                elif mode == "bar":
                    loop_b64 = msg.get("loop_audio_b64")
                    if not loop_b64:
                        await send_json({"type": "error", "error": "loop_audio_b64 required for mode=bar when no session_id"})
                        continue
                    worker = _ws_build_bar_worker(loop_b64, params)
                    new_sid = str(uuid.uuid4())
                    # Single active jam per GPU, mirroring /jam/start. Decide under
                    # the lock, but never await while holding a threading lock.
                    with jam_lock:
                        busy = any(w.is_alive() for w in jam_registry.values())
                        if not busy:
                            jam_registry[new_sid] = worker
                            worker.start()
                    if busy:
                        await send_json({"type": "error", "error": "A jam is already running"})
                        continue  # do not also report "started"
                    sid = new_sid
                else:
                    # rt-mode (Colab-style, no loop context)
                    rt = _ws_init_rt_state(params)
                    await send_json({"type": "started", "mode": "rt"})
                    tasks.append(asyncio.create_task(_ws_rt_stream(websocket, rt, binary_audio)))
                    continue  # skip the bar-mode "started" message below

                await send_json({"type": "started", "session_id": sid, "mode": mode})

                # in bar-mode, begin pushing chunks as they arrive
                if mode == "bar":
                    tasks.append(asyncio.create_task(_ws_pump_chunks(websocket, worker, binary_audio)))

            # --- UPDATES (bar or rt) ---
            elif mtype == "update":
                if mode == "bar":
                    if not sid:
                        await send_json({"type": "error", "error": "No session_id yet"})
                        continue
                    # fan values straight into the existing HTTP handler
                    res = jam_update(
                        session_id=sid,
                        guidance_weight=msg.get("guidance_weight"),
                        temperature=msg.get("temperature"),
                        topk=msg.get("topk"),
                        styles=msg.get("styles", ""),
                        style_weights=msg.get("style_weights", ""),
                        loop_weight=msg.get("loop_weight"),
                        use_current_mix_as_style=bool(msg.get("use_current_mix_as_style", False)),
                    )
                    await send_json({"type": "status", **res})
                elif rt is None:
                    await send_json({"type": "error", "error": "No rt stream started"})
                else:
                    # rt-mode: there is no JamWorker; update the local knobs/state
                    rt["temp"] = _ws_opt(msg, "temperature", rt["temp"], float)
                    rt["topk"] = _ws_opt(msg, "topk", rt["topk"], int)
                    rt["guid"] = _ws_opt(msg, "guidance_weight", rt["guid"], float)
                    if "styles" in msg:
                        rt["style"] = rt["mrt"].embed_style(msg["styles"])
                    await send_json({"type": "status", "updated": "rt-knobs"})

            elif mtype == "consume" and mode == "bar":
                worker = _ws_find_worker(msg.get("session_id") or sid)
                if worker is not None:
                    worker.mark_chunk_consumed(_ws_opt(msg, "chunk_index", -1, int))

            elif mtype == "reseed" and mode == "bar":
                worker = _ws_find_worker(msg.get("session_id") or sid)
                if worker is None or not worker.is_alive():
                    await send_json({"type": "error", "error": "Session not found"})
                    continue
                loop_b64 = msg.get("loop_audio_b64")
                if not loop_b64:
                    await send_json({"type": "error", "error": "loop_audio_b64 required"})
                    continue
                with _ws_temp_wav(loop_b64) as path:
                    wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
                worker.reseed_from_waveform(wav)
                await send_json({"type": "status", "reseeded": True})

            elif mtype == "reseed_splice" and mode == "bar":
                worker = _ws_find_worker(msg.get("session_id") or sid)
                if worker is None or not worker.is_alive():
                    await send_json({"type": "error", "error": "Session not found"})
                    continue
                anchor = _ws_opt(msg, "anchor_bars", 2.0, float)
                b64 = msg.get("combined_audio_b64")
                if b64:
                    with _ws_temp_wav(b64) as path:
                        wav = au.Waveform.from_file(path).resample(worker.mrt.sample_rate).as_stereo()
                    worker.reseed_splice(wav, anchor_bars=anchor)
                else:
                    # fallback: splice against the worker's stored combined loop
                    worker.reseed_splice(worker.params.combined_loop, anchor_bars=anchor)
                await send_json({"type": "status", "splice": anchor})

            elif mtype == "stop":
                halt_streams()
                if mode != "rt":
                    worker = _ws_find_worker(msg.get("session_id") or sid)
                    if worker is not None:
                        worker.stop()
                await send_json({"type": "stopped"})
                break

            elif mtype == "ping":
                await send_json({"type": "pong"})

            else:
                await send_json({"type": "error", "error": f"Unknown type {mtype}"})

    except WebSocketDisconnect:
        # Client went away; streaming tasks are stopped in ``finally``.
        pass
    except Exception as e:
        await _ws_report_error(websocket, e)
    finally:
        # Stop rt generation / chunk pumping tied to this socket.
        halt_streams()
        try:
            await websocket.close()
        except Exception:
            pass  # already closed or disconnected