thecollabagepatch commited on
Commit
6896250
·
1 Parent(s): 2bd198e

realtime flag added to websockets route

Browse files
Files changed (1) hide show
  1. app.py +54 -10
app.py CHANGED
@@ -23,6 +23,7 @@ from typing import Optional
23
 
24
 
25
  import json, asyncio, base64
 
26
 
27
  # --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
28
  def _patch_t5x_for_gpu_coords():
@@ -902,6 +903,27 @@ def read_root():
902
  """
903
  return Response(content=html_content, media_type="text/html")
904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
905
  @app.websocket("/ws/jam")
906
  async def ws_jam(websocket: WebSocket):
907
  await websocket.accept()
@@ -1004,44 +1026,61 @@ async def ws_jam(websocket: WebSocket):
1004
  # seed a fresh state with a silent context like warmup
1005
  mrt = get_mrt()
1006
  state = mrt.init_state()
1007
- # build exact-length silent context
1008
  codec_fps = float(mrt.codec.frame_rate)
1009
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1010
  sr = int(mrt.sample_rate)
1011
  samples = int(max(1, round(ctx_seconds * sr)))
1012
  silent = au.Waveform(np.zeros((samples,2), np.float32), sr)
1013
- tokens_full = mrt.codec.encode(silent).astype(np.int32)
1014
- tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
1015
  state.context_tokens = tokens
1016
- # keep local “rt” loop state
1017
  websocket._mrt = mrt
1018
  websocket._state = state
1019
- websocket._style = mrt.embed_style(params.get("styles","warmup"))
 
 
1020
  websocket._rt_running = True
1021
  websocket._rt_sr = sr
1022
  websocket._rt_topk = int(params.get("topk", 40))
1023
  websocket._rt_temp = float(params.get("temperature", 1.1))
1024
  websocket._rt_guid = float(params.get("guidance_weight", 1.1))
 
1025
  await send_json({"type":"started","mode":"rt"})
1026
  # kick off a background task to stream ~2s chunks
1027
  async def _rt_loop():
1028
  try:
 
 
 
1029
  while websocket._rt_running:
 
1030
  mrt.guidance_weight = websocket._rt_guid
1031
- mrt.temperature = websocket._rt_temp
1032
- mrt.topk = websocket._rt_topk
 
 
1033
  wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style)
1034
  websocket._state = new_state
 
1035
  x = wav.samples.astype(np.float32, copy=False)
1036
  buf = io.BytesIO()
1037
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
 
1038
  if binary_audio:
1039
  await websocket.send_bytes(buf.getvalue())
1040
  await send_json({"type":"chunk_meta","metadata":{"sample_rate":mrt.sample_rate}})
1041
  else:
1042
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1043
  await send_json({"type":"chunk","audio_base64":b64,
1044
- "metadata":{"sample_rate":mrt.sample_rate}})
 
 
 
 
 
 
 
 
1045
  except Exception as e:
1046
  await send_json({"type":"error","error":str(e)})
1047
  asyncio.create_task(_rt_loop())
@@ -1088,8 +1127,13 @@ async def ws_jam(websocket: WebSocket):
1088
  websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
1089
  websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1090
  websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1091
- if "styles" in msg:
1092
- websocket._style = websocket._mrt.embed_style(msg["styles"])
 
 
 
 
 
1093
  await send_json({"type":"status","updated":"rt-knobs"})
1094
 
1095
  elif mtype == "consume" and mode == "bar":
 
23
 
24
 
25
  import json, asyncio, base64
26
+ import time
27
 
28
  # --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
29
  def _patch_t5x_for_gpu_coords():
 
903
  """
904
  return Response(content=html_content, media_type="text/html")
905
 
906
+
907
+ # ----------------------------
908
+ # websockets route
909
+ # ----------------------------
910
+
911
+
912
+
913
+ def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
914
+ extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
915
+ if not extra:
916
+ return mrt.embed_style("warmup")
917
+ sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
918
+ embeds, weights = [], []
919
+ for i, s in enumerate(extra):
920
+ embeds.append(mrt.embed_style(s))
921
+ weights.append(sw[i] if i < len(sw) else 1.0)
922
+ wsum = sum(weights) or 1.0
923
+ weights = [w/wsum for w in weights]
924
+ import numpy as np
925
+ return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
926
+
927
  @app.websocket("/ws/jam")
928
  async def ws_jam(websocket: WebSocket):
929
  await websocket.accept()
 
1026
  # seed a fresh state with a silent context like warmup
1027
  mrt = get_mrt()
1028
  state = mrt.init_state()
 
1029
  codec_fps = float(mrt.codec.frame_rate)
1030
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
1031
  sr = int(mrt.sample_rate)
1032
  samples = int(max(1, round(ctx_seconds * sr)))
1033
  silent = au.Waveform(np.zeros((samples,2), np.float32), sr)
1034
+ tokens = mrt.codec.encode(silent).astype(np.int32)[:, :mrt.config.decoder_codec_rvq_depth]
 
1035
  state.context_tokens = tokens
1036
+
1037
  websocket._mrt = mrt
1038
  websocket._state = state
1039
+ websocket._style = _combine_styles(mrt,
1040
+ params.get("styles","warmup"),
1041
+ params.get("style_weights",""))
1042
  websocket._rt_running = True
1043
  websocket._rt_sr = sr
1044
  websocket._rt_topk = int(params.get("topk", 40))
1045
  websocket._rt_temp = float(params.get("temperature", 1.1))
1046
  websocket._rt_guid = float(params.get("guidance_weight", 1.1))
1047
+ websocket._pace = params.get("pace", "asap") # "realtime" | "asap"
1048
  await send_json({"type":"started","mode":"rt"})
1049
  # kick off a background task to stream ~2s chunks
1050
  async def _rt_loop():
1051
  try:
1052
+ mrt = websocket._mrt
1053
+ chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1054
+ target_next = time.perf_counter()
1055
  while websocket._rt_running:
1056
+ t0 = time.perf_counter()
1057
  mrt.guidance_weight = websocket._rt_guid
1058
+ mrt.temperature = websocket._rt_temp
1059
+ mrt.topk = websocket._rt_topk
1060
+
1061
+ # style already in websocket._style
1062
  wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style)
1063
  websocket._state = new_state
1064
+
1065
  x = wav.samples.astype(np.float32, copy=False)
1066
  buf = io.BytesIO()
1067
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1068
+
1069
  if binary_audio:
1070
  await websocket.send_bytes(buf.getvalue())
1071
  await send_json({"type":"chunk_meta","metadata":{"sample_rate":mrt.sample_rate}})
1072
  else:
1073
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1074
  await send_json({"type":"chunk","audio_base64":b64,
1075
+ "metadata":{"sample_rate":mrt.sample_rate}})
1076
+
1077
+ # --- pacing ---
1078
+ if getattr(websocket, "_pace", "asap") == "realtime":
1079
+ t1 = time.perf_counter()
1080
+ target_next += chunk_secs
1081
+ sleep_s = max(0.0, target_next - t1 - 0.02) # tiny safety margin
1082
+ if sleep_s > 0:
1083
+ await asyncio.sleep(sleep_s)
1084
  except Exception as e:
1085
  await send_json({"type":"error","error":str(e)})
1086
  asyncio.create_task(_rt_loop())
 
1127
  websocket._rt_temp = float(msg.get("temperature", websocket._rt_temp))
1128
  websocket._rt_topk = int(msg.get("topk", websocket._rt_topk))
1129
  websocket._rt_guid = float(msg.get("guidance_weight", websocket._rt_guid))
1130
+
1131
+ if ("styles" in msg) or ("style_weights" in msg):
1132
+ websocket._style = _combine_styles(
1133
+ websocket._mrt,
1134
+ msg.get("styles", ""),
1135
+ msg.get("style_weights", "")
1136
+ )
1137
  await send_json({"type":"status","updated":"rt-knobs"})
1138
 
1139
  elif mtype == "consume" and mode == "bar":