Commit
·
d54b5ce
1
Parent(s):
6896250
cleaning up ws close
Browse files
app.py
CHANGED
@@ -24,6 +24,24 @@ from typing import Optional
|
|
24 |
|
25 |
import json, asyncio, base64
|
26 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
|
29 |
def _patch_t5x_for_gpu_coords():
|
@@ -932,8 +950,9 @@ async def ws_jam(websocket: WebSocket):
|
|
932 |
binary_audio = False
|
933 |
mode = "rt" # or "bar"
|
934 |
|
|
|
935 |
async def send_json(obj):
|
936 |
-
await websocket
|
937 |
|
938 |
try:
|
939 |
while True:
|
@@ -1053,12 +1072,11 @@ async def ws_jam(websocket: WebSocket):
|
|
1053 |
chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
|
1054 |
target_next = time.perf_counter()
|
1055 |
while websocket._rt_running:
|
1056 |
-
|
1057 |
mrt.guidance_weight = websocket._rt_guid
|
1058 |
mrt.temperature = websocket._rt_temp
|
1059 |
mrt.topk = websocket._rt_topk
|
1060 |
|
1061 |
-
# style already in websocket._style
|
1062 |
wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style)
|
1063 |
websocket._state = new_state
|
1064 |
|
@@ -1066,24 +1084,38 @@ async def ws_jam(websocket: WebSocket):
|
|
1066 |
buf = io.BytesIO()
|
1067 |
sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
|
1068 |
|
|
|
|
|
1069 |
if binary_audio:
|
1070 |
-
|
1071 |
-
|
|
|
|
|
|
|
1072 |
else:
|
1073 |
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
1074 |
-
await send_json({"type":"chunk","audio_base64":b64,
|
1075 |
-
|
1076 |
|
1077 |
-
|
|
|
|
|
|
|
|
|
1078 |
if getattr(websocket, "_pace", "asap") == "realtime":
|
1079 |
t1 = time.perf_counter()
|
1080 |
target_next += chunk_secs
|
1081 |
-
sleep_s = max(0.0, target_next - t1 - 0.02)
|
1082 |
if sleep_s > 0:
|
1083 |
await asyncio.sleep(sleep_s)
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
|
|
|
|
|
|
|
|
|
|
1087 |
continue # skip the “bar-mode started” message below
|
1088 |
|
1089 |
await send_json({"type":"started","session_id": sid, "mode": mode})
|
@@ -1178,12 +1210,13 @@ async def ws_jam(websocket: WebSocket):
|
|
1178 |
elif mtype == "stop":
|
1179 |
if mode == "rt":
|
1180 |
websocket._rt_running = False
|
1181 |
-
|
1182 |
-
|
1183 |
-
|
1184 |
-
|
1185 |
-
|
1186 |
-
|
|
|
1187 |
|
1188 |
elif mtype == "ping":
|
1189 |
await send_json({"type":"pong"})
|
@@ -1201,6 +1234,7 @@ async def ws_jam(websocket: WebSocket):
|
|
1201 |
pass
|
1202 |
finally:
|
1203 |
try:
|
1204 |
-
|
|
|
1205 |
except Exception:
|
1206 |
pass
|
|
|
24 |
|
25 |
import json, asyncio, base64
|
26 |
import time
|
27 |
+
from starlette.websockets import WebSocketState
|
28 |
+
try:
|
29 |
+
from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
|
30 |
+
except Exception:
|
31 |
+
class ClientDisconnected(Exception): # fallback
|
32 |
+
pass
|
33 |
+
|
34 |
+
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
35 |
+
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
36 |
+
if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
|
37 |
+
return False
|
38 |
+
try:
|
39 |
+
await ws.send_text(json.dumps(obj))
|
40 |
+
return True
|
41 |
+
except (WebSocketDisconnect, ClientDisconnected, RuntimeError):
|
42 |
+
return False
|
43 |
+
except Exception:
|
44 |
+
return False
|
45 |
|
46 |
# --- Patch T5X mesh helpers for GPUs on JAX >= 0.7 (coords present, no core_on_chip) ---
|
47 |
def _patch_t5x_for_gpu_coords():
|
|
|
950 |
binary_audio = False
|
951 |
mode = "rt" # or "bar"
|
952 |
|
953 |
+
# NEW: capture ws in closure
|
954 |
async def send_json(obj):
|
955 |
+
return await send_json_safe(websocket, obj)
|
956 |
|
957 |
try:
|
958 |
while True:
|
|
|
1072 |
chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
|
1073 |
target_next = time.perf_counter()
|
1074 |
while websocket._rt_running:
|
1075 |
+
# read knobs (already set by update)
|
1076 |
mrt.guidance_weight = websocket._rt_guid
|
1077 |
mrt.temperature = websocket._rt_temp
|
1078 |
mrt.topk = websocket._rt_topk
|
1079 |
|
|
|
1080 |
wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style)
|
1081 |
websocket._state = new_state
|
1082 |
|
|
|
1084 |
buf = io.BytesIO()
|
1085 |
sf.write(buf, x, mrt.sample_rate, subtype="FLOAT", format="WAV")
|
1086 |
|
1087 |
+
# send bytes / json best-effort
|
1088 |
+
ok = True
|
1089 |
if binary_audio:
|
1090 |
+
try:
|
1091 |
+
await websocket.send_bytes(buf.getvalue())
|
1092 |
+
ok = await send_json({"type":"chunk_meta","metadata":{"sample_rate":mrt.sample_rate}})
|
1093 |
+
except Exception:
|
1094 |
+
ok = False
|
1095 |
else:
|
1096 |
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
1097 |
+
ok = await send_json({"type":"chunk","audio_base64":b64,
|
1098 |
+
"metadata":{"sample_rate":mrt.sample_rate}})
|
1099 |
|
1100 |
+
if not ok:
|
1101 |
+
# client went away — exit cleanly
|
1102 |
+
break
|
1103 |
+
|
1104 |
+
# pacing (use captured flag from start)
|
1105 |
if getattr(websocket, "_pace", "asap") == "realtime":
|
1106 |
t1 = time.perf_counter()
|
1107 |
target_next += chunk_secs
|
1108 |
+
sleep_s = max(0.0, target_next - t1 - 0.02)
|
1109 |
if sleep_s > 0:
|
1110 |
await asyncio.sleep(sleep_s)
|
1111 |
+
|
1112 |
+
except asyncio.CancelledError:
|
1113 |
+
# normal on stop/close — just exit
|
1114 |
+
pass
|
1115 |
+
except Exception:
|
1116 |
+
# don't try to send an error; socket may be closed
|
1117 |
+
pass
|
1118 |
+
websocket._rt_task = asyncio.create_task(_rt_loop())
|
1119 |
continue # skip the “bar-mode started” message below
|
1120 |
|
1121 |
await send_json({"type":"started","session_id": sid, "mode": mode})
|
|
|
1210 |
elif mtype == "stop":
|
1211 |
if mode == "rt":
|
1212 |
websocket._rt_running = False
|
1213 |
+
task = getattr(websocket, "_rt_task", None)
|
1214 |
+
if task is not None:
|
1215 |
+
task.cancel()
|
1216 |
+
try: await task
|
1217 |
+
except asyncio.CancelledError: pass
|
1218 |
+
await send_json({"type":"stopped"})
|
1219 |
+
break # <- add this if you want to end the socket after stop
|
1220 |
|
1221 |
elif mtype == "ping":
|
1222 |
await send_json({"type":"pong"})
|
|
|
1234 |
pass
|
1235 |
finally:
|
1236 |
try:
|
1237 |
+
if websocket.client_state != WebSocketState.DISCONNECTED:
|
1238 |
+
await websocket.close()
|
1239 |
except Exception:
|
1240 |
pass
|