thecollabagepatch commited on
Commit
d54b5ce
·
1 Parent(s): 6896250

cleaning up ws close

Browse files
Files changed (1) hide show
  1. app.py +53 -19
app.py CHANGED
@@ -24,6 +24,24 @@ from typing import Optional
24
 
25
  import json, asyncio, base64
26
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
29
  def _patch_t5x_for_gpu_coords():
@@ -932,8 +950,9 @@ async def ws_jam(websocket: WebSocket):
932
  binary_audio = False
933
  mode = "rt" # or "bar"
934
 
 
935
  async def send_json(obj):
936
- await websocket.send_text(json.dumps(obj))
937
 
938
  try:
939
  while True:
@@ -1053,12 +1072,11 @@ async def ws_jam(websocket: WebSocket):
1053
  chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1054
  target_next = time.perf_counter()
1055
  while websocket._rt_running:
1056
- t0 = time.perf_counter()
1057
  mrt.guidance_weight = websocket._rt_guid
1058
  mrt.temperature = websocket._rt_temp
1059
  mrt.topk = websocket._rt_topk
1060
 
1061
- # style already in websocket._style
1062
  wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style)
1063
  websocket._state = new_state
1064
 
@@ -1066,24 +1084,38 @@ async def ws_jam(websocket: WebSocket):
1066
  buf = io.BytesIO()
1067
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1068
 
 
 
1069
  if binary_audio:
1070
- await websocket.send_bytes(buf.getvalue())
1071
- await send_json({"type":"chunk_meta","metadata":{"sample_rate":mrt.sample_rate}})
 
 
 
1072
  else:
1073
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1074
- await send_json({"type":"chunk","audio_base64":b64,
1075
- "metadata":{"sample_rate":mrt.sample_rate}})
1076
 
1077
- # --- pacing ---
 
 
 
 
1078
  if getattr(websocket, "_pace", "asap") == "realtime":
1079
  t1 = time.perf_counter()
1080
  target_next += chunk_secs
1081
- sleep_s = max(0.0, target_next - t1 - 0.02) # tiny safety margin
1082
  if sleep_s > 0:
1083
  await asyncio.sleep(sleep_s)
1084
- except Exception as e:
1085
- await send_json({"type":"error","error":str(e)})
1086
- asyncio.create_task(_rt_loop())
 
 
 
 
 
1087
  continue # skip the “bar-mode started” message below
1088
 
1089
  await send_json({"type":"started","session_id": sid, "mode": mode})
@@ -1178,12 +1210,13 @@ async def ws_jam(websocket: WebSocket):
1178
  elif mtype == "stop":
1179
  if mode == "rt":
1180
  websocket._rt_running = False
1181
- else:
1182
- with jam_lock:
1183
- worker = jam_registry.get(msg.get("session_id"))
1184
- if worker is not None:
1185
- worker.stop()
1186
- await send_json({"type":"stopped"}); break
 
1187
 
1188
  elif mtype == "ping":
1189
  await send_json({"type":"pong"})
@@ -1201,6 +1234,7 @@ async def ws_jam(websocket: WebSocket):
1201
  pass
1202
  finally:
1203
  try:
1204
- await websocket.close()
 
1205
  except Exception:
1206
  pass
 
24
 
25
  import json, asyncio, base64
26
  import time
27
+ from starlette.websockets import WebSocketState
28
+ try:
29
+ from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
30
+ except Exception:
31
+ class ClientDisconnected(Exception): # fallback
32
+ pass
33
+
34
+ async def send_json_safe(ws: WebSocket, obj) -> bool:
35
+ """Try to send. Returns False if the socket is (or becomes) closed."""
36
+ if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
37
+ return False
38
+ try:
39
+ await ws.send_text(json.dumps(obj))
40
+ return True
41
+ except (WebSocketDisconnect, ClientDisconnected, RuntimeError):
42
+ return False
43
+ except Exception:
44
+ return False
45
 
46
  # --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
47
  def _patch_t5x_for_gpu_coords():
 
950
  binary_audio = False
951
  mode = "rt" # or "bar"
952
 
953
+ # NEW: capture ws in closure
954
  async def send_json(obj):
955
+ return await send_json_safe(websocket, obj)
956
 
957
  try:
958
  while True:
 
1072
  chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
1073
  target_next = time.perf_counter()
1074
  while websocket._rt_running:
1075
+ # read knobs (already set by update)
1076
  mrt.guidance_weight = websocket._rt_guid
1077
  mrt.temperature = websocket._rt_temp
1078
  mrt.topk = websocket._rt_topk
1079
 
 
1080
  wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style)
1081
  websocket._state = new_state
1082
 
 
1084
  buf = io.BytesIO()
1085
  sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
1086
 
1087
+ # send bytes / json best-effort
1088
+ ok = True
1089
  if binary_audio:
1090
+ try:
1091
+ await websocket.send_bytes(buf.getvalue())
1092
+ ok = await send_json({"type":"chunk_meta","metadata":{"sample_rate":mrt.sample_rate}})
1093
+ except Exception:
1094
+ ok = False
1095
  else:
1096
  b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
1097
+ ok = await send_json({"type":"chunk","audio_base64":b64,
1098
+ "metadata":{"sample_rate":mrt.sample_rate}})
1099
 
1100
+ if not ok:
1101
+ # client went away — exit cleanly
1102
+ break
1103
+
1104
+ # pacing (use captured flag from start)
1105
  if getattr(websocket, "_pace", "asap") == "realtime":
1106
  t1 = time.perf_counter()
1107
  target_next += chunk_secs
1108
+ sleep_s = max(0.0, target_next - t1 - 0.02)
1109
  if sleep_s > 0:
1110
  await asyncio.sleep(sleep_s)
1111
+
1112
+ except asyncio.CancelledError:
1113
+ # normal on stop/close — just exit
1114
+ pass
1115
+ except Exception:
1116
+ # don't try to send an error; socket may be closed
1117
+ pass
1118
+ websocket._rt_task = asyncio.create_task(_rt_loop())
1119
  continue # skip the “bar-mode started” message below
1120
 
1121
  await send_json({"type":"started","session_id": sid, "mode": mode})
 
1210
  elif mtype == "stop":
1211
  if mode == "rt":
1212
  websocket._rt_running = False
1213
+ task = getattr(websocket, "_rt_task", None)
1214
+ if task is not None:
1215
+ task.cancel()
1216
+ try: await task
1217
+ except asyncio.CancelledError: pass
1218
+ await send_json({"type":"stopped"})
1219
+ break # <- add this if you want to end the socket after stop
1220
 
1221
  elif mtype == "ping":
1222
  await send_json({"type":"pong"})
 
1234
  pass
1235
  finally:
1236
  try:
1237
+ if websocket.client_state != WebSocketState.DISCONNECTED:
1238
+ await websocket.close()
1239
  except Exception:
1240
  pass