Commit
·
7ae6392
1
Parent(s):
7fe8be5
a shot in the dark
Browse files- jam_worker.py +275 -518
jam_worker.py
CHANGED
@@ -1,16 +1,26 @@
|
|
1 |
-
# jam_worker.py -
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
4 |
import numpy as np
|
5 |
-
import soundfile as sf
|
6 |
from magenta_rt import audio as au
|
7 |
-
|
8 |
from utils import (
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
12 |
)
|
13 |
|
|
|
|
|
|
|
|
|
14 |
@dataclass
|
15 |
class JamParams:
|
16 |
bpm: float
|
@@ -19,558 +29,305 @@ class JamParams:
|
|
19 |
target_sr: int
|
20 |
loudness_mode: str = "auto"
|
21 |
headroom_db: float = 1.0
|
22 |
-
style_vec: np.ndarray
|
23 |
-
ref_loop:
|
24 |
-
combined_loop:
|
25 |
guidance_weight: float = 1.1
|
26 |
temperature: float = 1.1
|
27 |
topk: int = 40
|
28 |
|
|
|
29 |
@dataclass
|
30 |
class JamChunk:
|
31 |
index: int
|
32 |
audio_base64: str
|
33 |
metadata: dict
|
34 |
|
35 |
-
class JamWorker(threading.Thread):
|
36 |
-
def __init__(self, mrt, params: JamParams):
|
37 |
-
super().__init__(daemon=True)
|
38 |
-
self.mrt = mrt
|
39 |
-
self.params = params
|
40 |
-
self.state = mrt.init_state()
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
|
50 |
-
self.
|
51 |
-
self.
|
|
|
52 |
|
53 |
-
|
54 |
-
self.
|
55 |
|
56 |
-
# NEW: Track delivery state
|
57 |
-
self._last_delivered_index = 0
|
58 |
-
self._max_buffer_ahead = 5
|
59 |
-
|
60 |
-
# Timing info
|
61 |
-
self.last_chunk_started_at = None
|
62 |
-
self.last_chunk_completed_at = None
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
|
|
|
|
|
|
|
68 |
|
69 |
-
def
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
self.params.bpm,
|
80 |
-
self.params.beats_per_bar,
|
81 |
-
ctx_seconds
|
82 |
-
)
|
83 |
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
fps=float(self.mrt.codec.frame_rate), # keep fractional fps
|
91 |
-
ctx_frames=self.mrt.config.context_length_frames,
|
92 |
-
beats_per_bar=self.params.beats_per_bar
|
93 |
-
)
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
with self._lock:
|
102 |
-
if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
|
103 |
-
self._original_context_tokens = np.copy(context_tokens) # shape: [T, depth]
|
104 |
|
105 |
-
|
106 |
-
|
|
|
|
|
107 |
|
108 |
-
|
109 |
-
self.
|
|
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
if temperature is not None: self.params.temperature = float(temperature)
|
115 |
-
if topk is not None: self.params.topk = int(topk)
|
116 |
-
|
117 |
-
def get_next_chunk(self) -> JamChunk | None:
|
118 |
-
"""Get the next sequential chunk (blocks/waits if not ready)"""
|
119 |
-
target_index = self._last_delivered_index + 1
|
120 |
-
|
121 |
-
# Wait for the target chunk to be ready (with timeout)
|
122 |
-
max_wait = 30.0 # seconds
|
123 |
-
start_time = time.time()
|
124 |
-
|
125 |
-
while time.time() - start_time < max_wait and not self._stop_event.is_set():
|
126 |
-
with self._lock:
|
127 |
-
# Look for the exact chunk we need
|
128 |
-
for chunk in self.outbox:
|
129 |
-
if chunk.index == target_index:
|
130 |
-
self._last_delivered_index = target_index
|
131 |
-
print(f"📦 Delivered chunk {target_index}")
|
132 |
-
return chunk
|
133 |
-
|
134 |
-
# Not ready yet, wait a bit
|
135 |
-
time.sleep(0.1)
|
136 |
-
|
137 |
-
# Timeout or stopped
|
138 |
-
return None
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
with self._lock:
|
143 |
-
self._last_delivered_index = max(self._last_delivered_index, chunk_index)
|
144 |
-
print(f"✅ Chunk {chunk_index} consumed")
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
# Don't generate if we're already too far ahead
|
150 |
-
if self.idx > self._last_delivered_index + self._max_buffer_ahead:
|
151 |
-
return False
|
152 |
-
return True
|
153 |
-
|
154 |
-
def _seconds_per_bar(self) -> float:
|
155 |
-
return self.params.beats_per_bar * (60.0 / self.params.bpm)
|
156 |
-
|
157 |
-
def _snap_and_encode(self, y, seconds, target_sr, bars):
|
158 |
-
cur_sr = int(self.mrt.sample_rate)
|
159 |
-
x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
|
160 |
-
x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
|
161 |
-
b64, total_samples, channels = wav_bytes_base64(x, target_sr)
|
162 |
-
meta = {
|
163 |
-
"bpm": int(round(self.params.bpm)),
|
164 |
-
"bars": int(bars),
|
165 |
-
"beats_per_bar": int(self.params.beats_per_bar),
|
166 |
-
"sample_rate": int(target_sr),
|
167 |
-
"channels": channels,
|
168 |
-
"total_samples": total_samples,
|
169 |
-
"seconds_per_bar": self._seconds_per_bar(),
|
170 |
-
"loop_duration_seconds": bars * self._seconds_per_bar(),
|
171 |
-
"guidance_weight": self.params.guidance_weight,
|
172 |
-
"temperature": self.params.temperature,
|
173 |
-
"topk": self.params.topk,
|
174 |
-
}
|
175 |
-
return b64, meta
|
176 |
-
|
177 |
-
def _append_model_chunk_to_stream(self, wav):
|
178 |
-
"""Incrementally append a model chunk with equal-power crossfade."""
|
179 |
-
xfade_s = float(self.mrt.config.crossfade_length)
|
180 |
-
sr = int(self.mrt.sample_rate)
|
181 |
-
xfade_n = int(round(xfade_s * sr))
|
182 |
-
|
183 |
-
s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
184 |
-
|
185 |
-
if getattr(self, "_stream", None) is None:
|
186 |
-
# First chunk: drop model pre-roll (xfade head)
|
187 |
-
if s.shape[0] > xfade_n:
|
188 |
-
self._stream = s[xfade_n:].astype(np.float32, copy=True)
|
189 |
-
else:
|
190 |
-
self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
|
191 |
-
self._next_emit_start = 0 # pointer into _stream (model SR samples)
|
192 |
-
return
|
193 |
|
194 |
-
|
195 |
-
if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
|
196 |
-
# Degenerate safeguard
|
197 |
-
self._stream = np.concatenate([self._stream, s], axis=0)
|
198 |
-
return
|
199 |
|
200 |
-
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
207 |
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
213 |
|
214 |
-
#
|
215 |
-
|
216 |
-
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
217 |
-
from utils import take_bar_aligned_tail, make_bar_aligned_context
|
218 |
|
219 |
-
|
|
|
|
|
|
|
220 |
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
221 |
-
|
222 |
-
context_tokens =
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
self.
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
#
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
""
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
)
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
return
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
frames_per_bar = self._frames_per_bar()
|
277 |
-
|
278 |
-
# 1) Anchor tail (whole bars)
|
279 |
-
anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
|
280 |
-
|
281 |
-
# 2) Fill remainder with recent (prefer whole bars)
|
282 |
-
a = anchor.shape[0]
|
283 |
-
remain = max(ctx_frames - a, 0)
|
284 |
-
|
285 |
-
recent = recent_tokens[:0]
|
286 |
-
used_recent = 0 # frames taken from the END of recent_tokens
|
287 |
-
if remain > 0:
|
288 |
-
bars_fit = remain // frames_per_bar
|
289 |
-
if bars_fit >= 1:
|
290 |
-
want_recent_frames = int(bars_fit * frames_per_bar)
|
291 |
-
used_recent = min(want_recent_frames, recent_tokens.shape[0])
|
292 |
-
recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
|
293 |
-
else:
|
294 |
-
used_recent = min(remain, recent_tokens.shape[0])
|
295 |
-
recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
|
296 |
-
|
297 |
-
# 3) Concat in order [anchor, recent]
|
298 |
-
if anchor.size or recent.size:
|
299 |
-
out = np.concatenate([anchor, recent], axis=0)
|
300 |
else:
|
301 |
-
|
302 |
-
|
303 |
|
304 |
-
#
|
305 |
-
if
|
306 |
-
|
|
|
|
|
|
|
307 |
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
if
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
out = out[:, :depth]
|
358 |
-
return out
|
359 |
-
|
360 |
-
|
361 |
-
def _realign_emit_pointer_to_bar(self, sr_model: int):
|
362 |
-
"""Advance _next_emit_start to the next bar boundary in model-sample space."""
|
363 |
-
bar_samps = int(round(self._seconds_per_bar() * sr_model))
|
364 |
-
if bar_samps <= 0:
|
365 |
-
return
|
366 |
-
phase = self._next_emit_start % bar_samps
|
367 |
-
if phase != 0:
|
368 |
-
self._next_emit_start += (bar_samps - phase)
|
369 |
-
|
370 |
-
def _prepare_stream_for_reseed_handoff(self):
|
371 |
-
# OLD: keep crossfade tail -> causes phase offset
|
372 |
-
# sr = int(self.mrt.sample_rate)
|
373 |
-
# xfade_s = float(self.mrt.config.crossfade_length)
|
374 |
-
# xfade_n = int(round(xfade_s * sr))
|
375 |
-
# if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0:
|
376 |
-
# tail = self._stream[-xfade_n:] if self._stream.shape[0] > xfade_n else self._stream
|
377 |
-
# self._stream = tail.copy()
|
378 |
-
# else:
|
379 |
-
# self._stream = None
|
380 |
-
|
381 |
-
# NEW: throw away the tail completely; start fresh
|
382 |
-
self._stream = None
|
383 |
-
|
384 |
-
self._next_emit_start = 0
|
385 |
-
self._needs_bar_realign = True
|
386 |
-
|
387 |
-
def reseed_splice(self, recent_wav, anchor_bars: float):
|
388 |
-
"""
|
389 |
-
Token-splice reseed queued for the next bar boundary between chunks.
|
390 |
-
"""
|
391 |
-
with self._lock:
|
392 |
-
if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
|
393 |
-
self._original_context_tokens = np.copy(self.state.context_tokens)
|
394 |
-
|
395 |
-
recent_tokens = self._make_recent_tokens_from_wave(recent_wav) # [T, depth]
|
396 |
-
new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
|
397 |
-
|
398 |
-
# Queue it; the run loop will install right after we finish the current slice
|
399 |
-
self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
|
400 |
-
|
401 |
-
# install the new context window
|
402 |
-
new_state = self.mrt.init_state()
|
403 |
-
new_state.context_tokens = new_ctx
|
404 |
-
self.state = new_state
|
405 |
-
|
406 |
-
self._prepare_stream_for_reseed_handoff()
|
407 |
-
|
408 |
-
# optional: ask streamer to drop an intro crossfade worth of audio right after reseed
|
409 |
-
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
410 |
|
411 |
def run(self):
|
412 |
-
|
413 |
-
|
414 |
-
chunk_secs = self.params.bars_per_chunk * spb
|
415 |
-
xfade = float(self.mrt.config.crossfade_length) # seconds
|
416 |
-
sr = int(self.mrt.sample_rate)
|
417 |
-
chunk_samps = int(round(chunk_secs * sr))
|
418 |
-
|
419 |
-
def _need(first_chunk_extra=False):
|
420 |
-
"""How many more samples we still need in the stream to emit next slice."""
|
421 |
-
have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
|
422 |
-
want = chunk_samps
|
423 |
-
if first_chunk_extra:
|
424 |
-
# reserve two bars extra so first-chunk onset alignment has material
|
425 |
-
want += int(round(2 * spb * sr))
|
426 |
-
return max(0, want - have)
|
427 |
-
|
428 |
-
def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
|
429 |
-
if x.ndim == 2: x = x.mean(axis=1)
|
430 |
-
x = np.abs(x).astype(np.float32)
|
431 |
-
w = max(1, int(round(win_ms * 1e-3 * sr)))
|
432 |
-
if w > 1:
|
433 |
-
kern = np.ones(w, dtype=np.float32) / float(w)
|
434 |
-
x = np.convolve(x, kern, mode="same")
|
435 |
-
d = np.diff(x, prepend=x[:1])
|
436 |
-
d[d < 0] = 0.0
|
437 |
-
return d
|
438 |
-
|
439 |
-
def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
|
440 |
-
"""Tempo-aware first-downbeat offset (positive => model late)."""
|
441 |
-
try:
|
442 |
-
max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
|
443 |
-
ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
|
444 |
-
n_bar = int(round(spb * sr))
|
445 |
-
ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
|
446 |
-
gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
|
447 |
-
if ref_tail.size == 0 or gen_head.size == 0:
|
448 |
-
return 0
|
449 |
-
|
450 |
-
# envelopes + z-score
|
451 |
-
import numpy as np
|
452 |
-
def _z(a):
|
453 |
-
m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
|
454 |
-
e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
|
455 |
-
e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
|
456 |
-
|
457 |
-
# upsample x4 for finer lag
|
458 |
-
def _upsample(a, r=4):
|
459 |
-
n = len(a); grid = np.arange(n, dtype=np.float32)
|
460 |
-
fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
|
461 |
-
return np.interp(fine, grid, a).astype(np.float32)
|
462 |
-
up = 4
|
463 |
-
e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
|
464 |
-
|
465 |
-
max_lag_u = int(round((max_ms / 1000.0) * sr * up))
|
466 |
-
seg = min(len(e_ref_u), len(e_gen_u))
|
467 |
-
e_ref_u = e_ref_u[-seg:]
|
468 |
-
pad = np.zeros(max_lag_u, dtype=np.float32)
|
469 |
-
e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
|
470 |
-
|
471 |
-
best_lag_u, best_score = 0, -1e9
|
472 |
-
for lag_u in range(-max_lag_u, max_lag_u + 1):
|
473 |
-
start = max_lag_u + lag_u
|
474 |
-
b = e_gen_u_pad[start : start + seg]
|
475 |
-
denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
|
476 |
-
score = float(np.dot(e_ref_u, b) / denom)
|
477 |
-
if score > best_score:
|
478 |
-
best_score, best_lag_u = score, lag_u
|
479 |
-
return int(round(best_lag_u / up))
|
480 |
-
except Exception:
|
481 |
-
return 0
|
482 |
-
|
483 |
-
print("🚀 JamWorker started (bar-aligned streaming)…")
|
484 |
-
|
485 |
-
while not self._stop_event.is_set():
|
486 |
-
if not self._should_generate_next_chunk():
|
487 |
-
time.sleep(0.25)
|
488 |
-
continue
|
489 |
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
self.mrt.topk = int(self.params.topk)
|
498 |
-
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
499 |
-
self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
|
500 |
-
need = _need(first_chunk_extra=(self.idx == 0))
|
501 |
-
|
502 |
-
if self._stop_event.is_set():
|
503 |
-
break
|
504 |
-
|
505 |
-
# 2) One-time: align the emit pointer to the groove
|
506 |
-
if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
|
507 |
-
ref_loop = self._reseed_ref_loop or self.params.combined_loop
|
508 |
-
if ref_loop is not None:
|
509 |
-
head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
|
510 |
-
seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
|
511 |
-
gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
|
512 |
-
offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
|
513 |
-
if offs != 0:
|
514 |
-
self._next_emit_start = max(0, self._next_emit_start + offs)
|
515 |
-
print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
|
516 |
-
self._realign_emit_pointer_to_bar(sr)
|
517 |
-
self._needs_bar_realign = False
|
518 |
-
self._reseed_ref_loop = None
|
519 |
-
|
520 |
-
# 3) Emit exactly bars_per_chunk × spb from the stream
|
521 |
-
start = self._next_emit_start
|
522 |
-
end = start + chunk_samps
|
523 |
-
if end > self._stream.shape[0]:
|
524 |
-
# shouldn't happen often; generate a bit more and loop
|
525 |
continue
|
526 |
|
527 |
-
|
528 |
-
self.
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
# 5) Resample + exact-length snap + encode
|
543 |
-
b64, meta = self._snap_and_encode(
|
544 |
-
y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
|
545 |
-
)
|
546 |
-
meta["xfade_seconds"] = xfade
|
547 |
-
|
548 |
-
# 6) Publish
|
549 |
-
with self._lock:
|
550 |
-
self.idx += 1
|
551 |
-
self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
|
552 |
-
if len(self.outbox) > 10:
|
553 |
-
cutoff = self._last_delivered_index - 5
|
554 |
-
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
555 |
-
|
556 |
-
# 👉 If a reseed was requested, apply it *now*, between chunks
|
557 |
-
if self._pending_reseed is not None:
|
558 |
-
pkg = self._pending_reseed
|
559 |
-
self._pending_reseed = None
|
560 |
-
|
561 |
-
new_state = self.mrt.init_state()
|
562 |
-
new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
|
563 |
-
self.state = new_state
|
564 |
-
|
565 |
-
# start a fresh stream and schedule one-time alignment
|
566 |
-
self._stream = None
|
567 |
-
self._next_emit_start = 0
|
568 |
-
self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
|
569 |
-
self._needs_bar_realign = True
|
570 |
-
|
571 |
-
print("🔁 Reseed installed at bar boundary; will realign before next slice")
|
572 |
-
|
573 |
-
print(f"✅ Completed chunk {self.idx}")
|
574 |
-
|
575 |
-
print("🛑 JamWorker stopped")
|
576 |
-
|
|
|
1 |
+
# jam_worker.py - Bar-locked spool rewrite
|
2 |
+
from __future__ import annotations
|
3 |
+
|
4 |
+
import threading, time
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from fractions import Fraction
|
7 |
+
from typing import Optional, Dict, Tuple, List
|
8 |
+
|
9 |
import numpy as np
|
|
|
10 |
from magenta_rt import audio as au
|
11 |
+
|
12 |
from utils import (
|
13 |
+
StreamingResampler,
|
14 |
+
match_loudness_to_reference,
|
15 |
+
make_bar_aligned_context,
|
16 |
+
take_bar_aligned_tail,
|
17 |
+
wav_bytes_base64,
|
18 |
)
|
19 |
|
20 |
+
# -----------------------------
|
21 |
+
# Data classes
|
22 |
+
# -----------------------------
|
23 |
+
|
24 |
@dataclass
|
25 |
class JamParams:
|
26 |
bpm: float
|
|
|
29 |
target_sr: int
|
30 |
loudness_mode: str = "auto"
|
31 |
headroom_db: float = 1.0
|
32 |
+
style_vec: Optional[np.ndarray] = None
|
33 |
+
ref_loop: Optional[au.Waveform] = None
|
34 |
+
combined_loop: Optional[au.Waveform] = None
|
35 |
guidance_weight: float = 1.1
|
36 |
temperature: float = 1.1
|
37 |
topk: int = 40
|
38 |
|
39 |
+
|
40 |
@dataclass
class JamChunk:
    """One emitted chunk of audio plus the metadata describing it."""

    # Sequence number of the chunk within the jam session.
    index: int
    # Encoded audio payload as a base64 string.
    audio_base64: str
    # Descriptive key/value pairs attached to the chunk.
    metadata: dict
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
+
# -----------------------------
|
48 |
+
# Helpers
|
49 |
+
# -----------------------------
|
50 |
|
51 |
+
class BarClock:
    """Sample-domain bar clock with drift-free absolute boundaries.

    The bar length is kept as an exact ``Fraction`` of samples, and every
    boundary is computed from the absolute bar count (never by accumulating
    rounded lengths), so rounding error cannot build up across chunks.
    """

    def __init__(self, target_sr: int, bpm: float, beats_per_bar: int, base_offset_samples: int = 0):
        """Create a clock.

        Args:
            target_sr: Sample rate (Hz) the boundaries are expressed in.
            bpm: Tempo in beats per minute; parsed from its decimal string so
                values such as ``90.5`` are represented exactly.
            beats_per_bar: Number of beats in one bar.
            base_offset_samples: Absolute sample index of the first bar line.

        Raises:
            ValueError: If ``target_sr``, ``bpm`` or ``beats_per_bar`` is not
                strictly positive (or ``bpm`` is not a finite number).
        """
        self.sr = int(target_sr)
        self.bpm = Fraction(str(bpm))  # exact decimal to avoid FP drift
        self.beats_per_bar = int(beats_per_bar)

        # Reject values that would otherwise surface as a ZeroDivisionError
        # or as silently negative/empty chunk bounds.
        if self.sr <= 0:
            raise ValueError(f"target_sr must be positive, got {target_sr!r}")
        if self.bpm <= 0:
            raise ValueError(f"bpm must be positive, got {bpm!r}")
        if self.beats_per_bar <= 0:
            raise ValueError(f"beats_per_bar must be positive, got {beats_per_bar!r}")

        # Exact (possibly fractional) number of samples in one bar.
        self.bar_samps = Fraction(self.sr * 60 * self.beats_per_bar, 1) / self.bpm
        self.base = int(base_offset_samples)

    def bounds_for_chunk(self, chunk_index: int, bars_per_chunk: int) -> Tuple[int, int]:
        """Return the ``(start, end)`` sample indices of chunk ``chunk_index``.

        The end of chunk ``i`` is computed with the same expression as the
        start of chunk ``i + 1``, so consecutive chunks tile without gaps or
        overlap.
        """
        first_bar = chunk_index * bars_per_chunk
        last_bar = (chunk_index + 1) * bars_per_chunk
        start_f = self.base + self.bar_samps * first_bar
        end_f = self.base + self.bar_samps * last_bar
        return int(round(start_f)), int(round(end_f))

    def seconds_per_bar(self) -> float:
        """Return the duration of one bar in seconds."""
        return float(self.beats_per_bar) * (60.0 / float(self.bpm))
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
# -----------------------------
|
70 |
+
# Worker
|
71 |
+
# -----------------------------
|
72 |
|
73 |
+
class JamWorker(threading.Thread):
    """Generate continuous audio with MagentaRT and emit bar-aligned chunks.

    Model output is crossfaded at the model sample rate, resampled into a
    target-rate spool, and cut into chunks at the absolute sample positions
    given by a :class:`BarClock`. Finished chunks are published to an outbox
    that :meth:`get_next_chunk` hands out strictly in index order.
    """

    def __init__(self, mrt, params: "JamParams"):
        """Set up generation state, the resampler/spool and delivery counters.

        Args:
            mrt: MagentaRT system object (provides ``init_state``,
                ``generate_chunk``, ``codec``, ``config`` and ``sample_rate``).
            params: Session parameters (tempo, target sample rate, knobs,
                optional style vector and reference loops).
        """
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params

        # generation state
        self.state = self.mrt.init_state()
        self.mrt.guidance_weight = float(self.params.guidance_weight)
        self.mrt.temperature = float(self.params.temperature)
        self.mrt.topk = int(self.params.topk)

        # style vector (already normalized upstream)
        self._style_vec = self.params.style_vec

        # codec/setup
        self._codec_fps = float(self.mrt.codec.frame_rate)
        self._ctx_frames = int(self.mrt.config.context_length_frames)
        self._ctx_seconds = self._ctx_frames / self._codec_fps

        # Model-rate continuity buffer. Only the last crossfade-length of
        # samples is ever read back, so only that tail is retained.
        self._model_stream: Optional[np.ndarray] = None
        self._model_sr = int(self.mrt.sample_rate)

        # target-SR in-RAM spool (what we cut loops from)
        self._rs = StreamingResampler(self._model_sr, int(self.params.target_sr), channels=2)
        self._spool = np.zeros((0, 2), dtype=np.float32)  # (S,2) target SR
        self._spool_written = 0  # absolute frames ever written into the spool
        self._spool_base = 0     # absolute frame index of self._spool[0]

        # bar clock: start with offset 0; if you have a downbeat estimator, set base later
        self._bar_clock = BarClock(
            self.params.target_sr,
            self.params.bpm,
            self.params.beats_per_bar,
            base_offset_samples=0,
        )

        # emission counters
        self.idx = 0                    # next chunk index to *produce*
        self._next_to_deliver = 0       # next chunk index to hand out via get_next_chunk()
        self._last_consumed_index = -1  # updated via mark_chunk_consumed(); throttles generation

        # outbox and synchronization
        self._outbox: Dict[int, JamChunk] = {}
        self._cv = threading.Condition()

        # control flags
        # Deliberately NOT named `_stop`: threading.Thread uses that name for a
        # private method in CPython releases that define it, and shadowing it
        # with an Event makes join()/is_alive() fail with a TypeError.
        self._stop_event = threading.Event()
        self._max_buffer_ahead = 5

        # reseed queue (install at next safe point)
        self._pending_reseed: Optional[dict] = None

        # context tokens of the most recent full (re)seed
        self._original_context_tokens: Optional[np.ndarray] = None

        # Prepare initial context from combined loop (best musical alignment)
        if self.params.combined_loop is not None:
            self._install_context_from_loop(self.params.combined_loop)

    # ---------- lifecycle ----------

    def stop(self):
        """Ask the worker loop to exit and wake any blocked consumers."""
        self._stop_event.set()
        with self._cv:
            self._cv.notify_all()

    # FastAPI reads this to block until the next sequential chunk is ready
    def get_next_chunk(self, timeout: float = 30.0) -> "Optional[JamChunk]":
        """Return the next chunk in sequence, waiting up to ``timeout`` seconds.

        Returns:
            The chunk whose index is the next one to deliver, or ``None`` if it
            did not become available before the timeout or the worker was
            stopped. The delivery pointer only advances on success.
        """
        # Monotonic clock: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + timeout
        with self._cv:
            while True:
                chunk = self._outbox.get(self._next_to_deliver)
                if chunk is not None:
                    self._next_to_deliver += 1
                    return chunk
                if self._stop_event.is_set():
                    # Nothing more is coming; do not make the caller wait out the timeout.
                    return None
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return None
                self._cv.wait(timeout=min(0.25, remaining))

    def mark_chunk_consumed(self, chunk_index: int):
        """Record that the client consumed ``chunk_index`` and purge old chunks.

        This lets the generator run ahead, but not too far: generation is
        throttled relative to the highest consumed index.
        """
        with self._cv:
            self._last_consumed_index = max(self._last_consumed_index, int(chunk_index))
            # Purge old chunks to cap memory, but never one that has not been
            # delivered yet, otherwise get_next_chunk() would wait on it forever.
            cutoff = min(self._last_consumed_index - 1, self._next_to_deliver)
            for stale in [k for k in self._outbox if k < cutoff]:
                del self._outbox[stale]

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Update any of the sampling knobs and push all three into ``mrt``."""
        if guidance_weight is not None:
            self.params.guidance_weight = float(guidance_weight)
        if temperature is not None:
            self.params.temperature = float(temperature)
        if topk is not None:
            self.params.topk = int(topk)
        # push into mrt (thread-safe enough for our use)
        self.mrt.guidance_weight = float(self.params.guidance_weight)
        self.mrt.temperature = float(self.params.temperature)
        self.mrt.topk = int(self.params.topk)

    # ---------- context / reseed ----------

    def _encode_context_tokens(self, wav: "au.Waveform") -> np.ndarray:
        """Encode the bar-aligned tail of ``wav`` into decoder context tokens."""
        wav = wav.as_stereo().resample(self._model_sr)
        tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
        tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
        depth = int(self.mrt.config.decoder_codec_rvq_depth)
        # NOTE(review): the token count follows the length of `tail`; nothing here
        # pads or trims it to context_length_frames — confirm the model accepts that.
        return tokens_full[:, :depth]

    def _install_state(self, context_tokens: np.ndarray):
        """Replace the generation state with a fresh one seeded by ``context_tokens``."""
        fresh = self.mrt.init_state()
        fresh.context_tokens = context_tokens
        self.state = fresh

    def _install_context_from_loop(self, loop: "au.Waveform"):
        """Seed the generation state from ``loop`` and remember its tokens."""
        context_tokens = self._encode_context_tokens(loop)
        self._install_state(context_tokens)
        # keep an original copy for future splices
        self._original_context_tokens = np.copy(context_tokens)

    def reseed_from_waveform(self, wav: "au.Waveform"):
        """Immediate reseed: replace context from provided wave (bar-aligned tail)."""
        context_tokens = self._encode_context_tokens(wav)
        self._install_state(context_tokens)
        # reset model stream so next generate starts cleanly
        self._model_stream = None
        # also remember this as new "original"
        self._original_context_tokens = np.copy(context_tokens)

    def reseed_splice(self, recent_wav: "au.Waveform", anchor_bars: float):
        """Queue a reseed to be applied right after the next emitted chunk.

        For now the context is simply replaced by the tail of ``recent_wav``;
        ``anchor_bars`` is accepted for API compatibility and is not used.
        """
        new_ctx = self._encode_context_tokens(recent_wav)
        with self._cv:
            self._pending_reseed = {"ctx": new_ctx}

    # ---------- core streaming helpers ----------

    @staticmethod
    def _crossfade_tail(stream: np.ndarray, xfade_n: int) -> np.ndarray:
        """Return a copy of at most the last ``xfade_n`` rows of ``stream``."""
        keep = min(max(int(xfade_n), 0), stream.shape[0])
        return stream[stream.shape[0] - keep:].copy()

    def _spool_append(self, resampled: np.ndarray):
        """Append already-resampled audio to the spool and bump the write counter."""
        if resampled.size:
            self._spool = np.concatenate([self._spool, resampled], axis=0)
            self._spool_written += resampled.shape[0]

    def _append_model_chunk_and_spool(self, wav: "au.Waveform"):
        """Crossfade ``wav`` into the model-rate stream and spool its new part.

        The first chunk after a (re)start has its crossfade-length head dropped
        as pre-roll. Later chunks are equal-power crossfaded against the kept
        tail, and the samples after the overlap are resampled into the spool.
        """
        s = wav.samples.astype(np.float32, copy=False)
        if s.ndim == 1:
            s = s[:, None]
        xfade_s = float(self.mrt.config.crossfade_length)
        xfade_n = int(round(max(0.0, xfade_s) * self._model_sr))

        prev = self._model_stream
        if prev is None:
            # first chunk: drop the preroll (xfade) then spool
            new_part = s[xfade_n:] if xfade_n < s.shape[0] else s[:0]
            stream = new_part
        elif xfade_n > 0 and prev.shape[0] >= xfade_n and s.shape[0] >= xfade_n:
            # equal-power crossfade of the kept tail into the head of the new chunk
            t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
            mixed = prev[-xfade_n:] * np.cos(t) + s[:xfade_n] * np.sin(t)
            new_part = s[xfade_n:]
            stream = np.concatenate([prev[:-xfade_n], mixed, new_part], axis=0)
            # NOTE(review): `mixed` only lands in the model stream; the spool already
            # holds the previous chunk's un-faded tail — confirm this is intended.
        else:
            # too short to crossfade: plain append
            new_part = s
            stream = np.concatenate([prev, s], axis=0)

        # Retain only what the next crossfade can read; keeping the full history
        # would grow without bound and be re-copied on every chunk.
        self._model_stream = self._crossfade_tail(stream, xfade_n)

        # spool only the *new* non-overlapped part
        if new_part.size:
            self._spool_append(self._rs.process(new_part.astype(np.float32, copy=False), final=False))

    def _should_generate_next_chunk(self) -> bool:
        """Return False when production is too far ahead of consumption."""
        return self.idx <= (self._last_consumed_index + self._max_buffer_ahead)

    def _emit_ready(self):
        """Emit next chunk(s) if the spool has enough samples."""
        target_sr = int(self.params.target_sr)
        ref = None  # resampled reference loop, built at most once per call
        while True:
            start, end = self._bar_clock.bounds_for_chunk(self.idx, self.params.bars_per_chunk)
            if end > self._spool_written:
                break  # need more audio

            # Translate absolute sample positions into the trimmed spool.
            lo = max(0, start - self._spool_base)
            hi = max(lo, end - self._spool_base)
            loop = self._spool[lo:hi]

            # Loudness match to reference loop (optional)
            if self.params.ref_loop is not None and self.params.loudness_mode != "none":
                if ref is None:
                    ref = self.params.ref_loop.as_stereo().resample(self.params.target_sr)
                wav = au.Waveform(loop.copy(), target_sr)
                matched, _ = match_loudness_to_reference(
                    ref, wav, method=self.params.loudness_mode, headroom_db=self.params.headroom_db
                )
                loop = matched.samples

            audio_b64, total_samples, channels = wav_bytes_base64(loop, target_sr)
            spb = self._bar_clock.seconds_per_bar()
            meta = {
                "bpm": float(self.params.bpm),
                "bars": int(self.params.bars_per_chunk),
                "beats_per_bar": int(self.params.beats_per_bar),
                "sample_rate": target_sr,
                "channels": int(channels),
                "total_samples": int(total_samples),
                "seconds_per_bar": spb,
                "loop_duration_seconds": self.params.bars_per_chunk * spb,
                "guidance_weight": float(self.params.guidance_weight),
                "temperature": float(self.params.temperature),
                "topk": int(self.params.topk),
            }
            chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)

            with self._cv:
                self._outbox[self.idx] = chunk
                self.idx += 1
                # Take the queued reseed (if any) under the lock so a request
                # arriving concurrently cannot be overwritten by the clear.
                pending, self._pending_reseed = self._pending_reseed, None
                self._cv.notify_all()

            # Drop the samples just emitted: the next chunk starts exactly at
            # `end`, so nothing before it is read again. (If the bar clock's
            # base were later moved backwards, that audio would be gone.)
            self._spool = self._spool[hi:].copy()
            self._spool_base += hi

            # If a reseed is queued, install it *right after* we finish a chunk
            if pending is not None:
                self._install_state(pending["ctx"])
                self._model_stream = None  # drop model-domain continuity so next chunk starts clean

    # ---------- main loop ----------

    def run(self):
        """Generate, spool and emit until :meth:`stop` is called, then flush."""
        style_vec = self._style_vec

        # generate until stopped
        while not self._stop_event.is_set():
            # throttle generation if we are far ahead
            if not self._should_generate_next_chunk():
                # still try to emit if spool already has enough
                self._emit_ready()
                # Event.wait instead of sleep so stop() is noticed immediately.
                self._stop_event.wait(0.01)
                continue

            # generate next model chunk
            wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
            # append and spool
            self._append_model_chunk_and_spool(wav)
            # try emitting zero or more chunks if available
            self._emit_ready()

        # finalize resampler (flush) — not strictly necessary here
        self._spool_append(self._rs.process(np.zeros((0, 2), np.float32), final=True))
        # one last emit attempt
        self._emit_ready()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|