thecollabagepatch committed on
Commit
7ae6392
·
1 Parent(s): 7fe8be5

a shot in the dark

Browse files
Files changed (1) hide show
  1. jam_worker.py +275 -518
jam_worker.py CHANGED
@@ -1,16 +1,26 @@
1
- # jam_worker.py - SIMPLE FIX VERSION
2
- import threading, time, base64, io, uuid
3
- from dataclasses import dataclass, field
 
 
 
 
 
4
  import numpy as np
5
- import soundfile as sf
6
  from magenta_rt import audio as au
7
- from threading import RLock
8
  from utils import (
9
- match_loudness_to_reference, stitch_generated, hard_trim_seconds,
10
- apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
- resample_and_snap, wav_bytes_base64
 
 
12
  )
13
 
 
 
 
 
14
  @dataclass
15
  class JamParams:
16
  bpm: float
@@ -19,558 +29,305 @@ class JamParams:
19
  target_sr: int
20
  loudness_mode: str = "auto"
21
  headroom_db: float = 1.0
22
- style_vec: np.ndarray | None = None
23
- ref_loop: any = None
24
- combined_loop: any = None
25
  guidance_weight: float = 1.1
26
  temperature: float = 1.1
27
  topk: int = 40
28
 
 
29
  @dataclass
30
  class JamChunk:
31
  index: int
32
  audio_base64: str
33
  metadata: dict
34
 
35
- class JamWorker(threading.Thread):
36
- def __init__(self, mrt, params: JamParams):
37
- super().__init__(daemon=True)
38
- self.mrt = mrt
39
- self.params = params
40
- self.state = mrt.init_state()
41
 
42
- # ✅ init synchronization + placeholders FIRST
43
- self._lock = threading.Lock()
44
- self._original_context_tokens = None # so hasattr checks are cheap/clear
45
 
46
- if params.combined_loop is not None:
47
- self._setup_context_from_combined_loop()
 
 
 
 
 
 
48
 
49
- self.idx = 0
50
- self.outbox: list[JamChunk] = []
51
- self._stop_event = threading.Event()
 
52
 
53
- self._stream = None
54
- self._next_emit_start = 0
55
 
56
- # NEW: Track delivery state
57
- self._last_delivered_index = 0
58
- self._max_buffer_ahead = 5
59
-
60
- # Timing info
61
- self.last_chunk_started_at = None
62
- self.last_chunk_completed_at = None
63
 
64
- self._pending_reseed = None # {"ctx": np.ndarray, "ref": au.Waveform|None}
65
- self._needs_bar_realign = False # request a one-shot downbeat alignment
66
- self._reseed_ref_loop = None # which loop to align against after reseed
67
 
 
 
 
68
 
69
- def _setup_context_from_combined_loop(self):
70
- """Set up MRT context tokens from the combined loop audio"""
71
- try:
72
- from utils import make_bar_aligned_context, take_bar_aligned_tail
73
 
74
- codec_fps = float(self.mrt.codec.frame_rate)
75
- ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
 
 
 
76
 
77
- loop_for_context = take_bar_aligned_tail(
78
- self.params.combined_loop,
79
- self.params.bpm,
80
- self.params.beats_per_bar,
81
- ctx_seconds
82
- )
83
 
84
- tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
85
- tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
 
 
86
 
87
- context_tokens = make_bar_aligned_context(
88
- tokens,
89
- bpm=self.params.bpm,
90
- fps=float(self.mrt.codec.frame_rate), # keep fractional fps
91
- ctx_frames=self.mrt.config.context_length_frames,
92
- beats_per_bar=self.params.beats_per_bar
93
- )
94
 
95
- # Install fresh context
96
- self.state.context_tokens = context_tokens
97
- print(f"✅ JamWorker: Set up fresh context from combined loop")
 
98
 
99
- # NEW: keep a copy of the *original* context tokens for future splice-reseed
100
- # (guard so we only set this once, at jam start)
101
- with self._lock:
102
- if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
103
- self._original_context_tokens = np.copy(context_tokens) # shape: [T, depth]
104
 
105
- except Exception as e:
106
- print(f"❌ Failed to setup context from combined loop: {e}")
 
 
107
 
108
- def stop(self):
109
- self._stop_event.set()
 
110
 
111
- def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
112
- with self._lock:
113
- if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
114
- if temperature is not None: self.params.temperature = float(temperature)
115
- if topk is not None: self.params.topk = int(topk)
116
-
117
- def get_next_chunk(self) -> JamChunk | None:
118
- """Get the next sequential chunk (blocks/waits if not ready)"""
119
- target_index = self._last_delivered_index + 1
120
-
121
- # Wait for the target chunk to be ready (with timeout)
122
- max_wait = 30.0 # seconds
123
- start_time = time.time()
124
-
125
- while time.time() - start_time < max_wait and not self._stop_event.is_set():
126
- with self._lock:
127
- # Look for the exact chunk we need
128
- for chunk in self.outbox:
129
- if chunk.index == target_index:
130
- self._last_delivered_index = target_index
131
- print(f"📦 Delivered chunk {target_index}")
132
- return chunk
133
-
134
- # Not ready yet, wait a bit
135
- time.sleep(0.1)
136
-
137
- # Timeout or stopped
138
- return None
139
 
140
- def mark_chunk_consumed(self, chunk_index: int):
141
- """Mark a chunk as consumed by the frontend"""
142
- with self._lock:
143
- self._last_delivered_index = max(self._last_delivered_index, chunk_index)
144
- print(f"✅ Chunk {chunk_index} consumed")
145
 
146
- def _should_generate_next_chunk(self) -> bool:
147
- """Check if we should generate the next chunk (don't get too far ahead)"""
148
- with self._lock:
149
- # Don't generate if we're already too far ahead
150
- if self.idx > self._last_delivered_index + self._max_buffer_ahead:
151
- return False
152
- return True
153
-
154
- def _seconds_per_bar(self) -> float:
155
- return self.params.beats_per_bar * (60.0 / self.params.bpm)
156
-
157
- def _snap_and_encode(self, y, seconds, target_sr, bars):
158
- cur_sr = int(self.mrt.sample_rate)
159
- x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
160
- x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
161
- b64, total_samples, channels = wav_bytes_base64(x, target_sr)
162
- meta = {
163
- "bpm": int(round(self.params.bpm)),
164
- "bars": int(bars),
165
- "beats_per_bar": int(self.params.beats_per_bar),
166
- "sample_rate": int(target_sr),
167
- "channels": channels,
168
- "total_samples": total_samples,
169
- "seconds_per_bar": self._seconds_per_bar(),
170
- "loop_duration_seconds": bars * self._seconds_per_bar(),
171
- "guidance_weight": self.params.guidance_weight,
172
- "temperature": self.params.temperature,
173
- "topk": self.params.topk,
174
- }
175
- return b64, meta
176
-
177
- def _append_model_chunk_to_stream(self, wav):
178
- """Incrementally append a model chunk with equal-power crossfade."""
179
- xfade_s = float(self.mrt.config.crossfade_length)
180
- sr = int(self.mrt.sample_rate)
181
- xfade_n = int(round(xfade_s * sr))
182
-
183
- s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
184
-
185
- if getattr(self, "_stream", None) is None:
186
- # First chunk: drop model pre-roll (xfade head)
187
- if s.shape[0] > xfade_n:
188
- self._stream = s[xfade_n:].astype(np.float32, copy=True)
189
- else:
190
- self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
191
- self._next_emit_start = 0 # pointer into _stream (model SR samples)
192
- return
193
 
194
- # Crossfade last xfade_n samples of _stream with head of new s
195
- if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
196
- # Degenerate safeguard
197
- self._stream = np.concatenate([self._stream, s], axis=0)
198
- return
199
 
200
- tail = self._stream[-xfade_n:]
201
- head = s[:xfade_n]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
- # Equal-power envelopes
204
- t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
205
- eq_in, eq_out = np.sin(t), np.cos(t)
206
- mixed = tail * eq_out + head * eq_in
 
 
 
 
207
 
208
- self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- def reseed_from_waveform(self, wav):
211
- # 1) Re-init state
212
- new_state = self.mrt.init_state()
 
213
 
214
- # 2) Build bar-aligned context tokens from provided audio
215
- codec_fps = float(self.mrt.codec.frame_rate)
216
- ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
217
- from utils import take_bar_aligned_tail, make_bar_aligned_context
218
 
219
- tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds)
 
 
 
220
  tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
221
- tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
222
- context_tokens = make_bar_aligned_context(tokens,
223
- bpm=self.params.bpm, fps=float(self.mrt.codec.frame_rate),
224
- ctx_frames=self.mrt.config.context_length_frames,
225
- beats_per_bar=self.params.beats_per_bar
226
- )
227
- new_state.context_tokens = context_tokens
228
- self.state = new_state
229
- self._prepare_stream_for_reseed_handoff()
230
-
231
- def _frames_per_bar(self) -> int:
232
- # codec frame-rate (frames/s) -> frames per musical bar
233
- fps = float(self.mrt.codec.frame_rate)
234
- sec_per_bar = (60.0 / float(self.params.bpm)) * float(self.params.beats_per_bar)
235
- return int(round(fps * sec_per_bar))
236
-
237
- def _ctx_frames(self) -> int:
238
- # how many codec frames fit in the model’s conditioning window
239
- return int(self.mrt.config.context_length_frames)
240
-
241
- def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
242
- """
243
- Encode waveform and produce a BAR-ALIGNED context token window.
244
- """
245
- tokens_full = self.mrt.codec.encode(wav).astype(np.int32) # [T, rvq_total]
246
- tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
247
-
248
- from utils import make_bar_aligned_context
249
- ctx = make_bar_aligned_context(
250
- tokens,
251
- bpm=self.params.bpm,
252
- fps=float(self.mrt.codec.frame_rate), # keep fractional fps
253
- ctx_frames=self.mrt.config.context_length_frames,
254
- beats_per_bar=self.params.beats_per_bar
255
- )
256
- return ctx
257
-
258
- def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
259
- """
260
- Take a tail slice that is an integer number of codec frames corresponding to `bars`.
261
- We round to nearest frame to stay phase-consistent with codec grid.
262
- """
263
- frames_per_bar = self._frames_per_bar()
264
- want = max(frames_per_bar * int(round(bars)), 0)
265
- if want == 0:
266
- return tokens[:0] # empty
267
- if tokens.shape[0] <= want:
268
- return tokens
269
- return tokens[-want:]
270
-
271
- def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
272
- anchor_bars: float) -> np.ndarray:
273
- import math
274
- ctx_frames = self._ctx_frames()
275
- depth = original_tokens.shape[1]
276
- frames_per_bar = self._frames_per_bar()
277
-
278
- # 1) Anchor tail (whole bars)
279
- anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
280
-
281
- # 2) Fill remainder with recent (prefer whole bars)
282
- a = anchor.shape[0]
283
- remain = max(ctx_frames - a, 0)
284
-
285
- recent = recent_tokens[:0]
286
- used_recent = 0 # frames taken from the END of recent_tokens
287
- if remain > 0:
288
- bars_fit = remain // frames_per_bar
289
- if bars_fit >= 1:
290
- want_recent_frames = int(bars_fit * frames_per_bar)
291
- used_recent = min(want_recent_frames, recent_tokens.shape[0])
292
- recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
293
- else:
294
- used_recent = min(remain, recent_tokens.shape[0])
295
- recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
296
-
297
- # 3) Concat in order [anchor, recent]
298
- if anchor.size or recent.size:
299
- out = np.concatenate([anchor, recent], axis=0)
300
  else:
301
- # fallback: just take the last ctx window from recent
302
- out = recent_tokens[-ctx_frames:]
303
 
304
- # 4) Trim if we overshot
305
- if out.shape[0] > ctx_frames:
306
- out = out[-ctx_frames:]
 
 
 
307
 
308
- # 5) Snap the **END** to the nearest LOWER bar boundary
309
- if frames_per_bar > 0:
310
- max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
311
- else:
312
- max_bar_aligned = out.shape[0]
313
- if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
314
- out = out[-max_bar_aligned:]
315
-
316
- # 6) Left-fill to reach ctx_frames **without moving the END**
317
- deficit = ctx_frames - out.shape[0]
318
- if deficit > 0:
319
- left_parts = []
320
-
321
- # Prefer frames immediately BEFORE the region we used from 'recent_tokens'
322
- if used_recent < recent_tokens.shape[0]:
323
- take = min(deficit, recent_tokens.shape[0] - used_recent)
324
- if used_recent > 0:
325
- left_parts.append(recent_tokens[-(used_recent + take) : -used_recent])
326
- else:
327
- left_parts.append(recent_tokens[-take:])
328
-
329
- # Then take frames immediately BEFORE the 'anchor' in original_tokens
330
- if sum(p.shape[0] for p in left_parts) < deficit and anchor.shape[0] > 0:
331
- need = deficit - sum(p.shape[0] for p in left_parts)
332
- a_len = anchor.shape[0]
333
- avail = max(original_tokens.shape[0] - a_len, 0)
334
- take2 = min(need, avail)
335
- if take2 > 0:
336
- left_parts.append(original_tokens[-(a_len + take2) : -a_len])
337
-
338
- # Still short? tile from what's available
339
- have = sum(p.shape[0] for p in left_parts)
340
- if have < deficit:
341
- base = out if out.shape[0] > 0 else (recent_tokens if recent_tokens.shape[0] > 0 else original_tokens)
342
- reps = int(np.ceil((deficit - have) / max(1, base.shape[0])))
343
- left_parts.append(np.tile(base, (reps, 1))[: (deficit - have)])
344
-
345
- left = np.concatenate(left_parts, axis=0)
346
- out = np.concatenate([left[-deficit:], out], axis=0)
347
-
348
- # 7) Final guard to exact length
349
- if out.shape[0] > ctx_frames:
350
- out = out[-ctx_frames:]
351
- elif out.shape[0] < ctx_frames:
352
- reps = int(np.ceil(ctx_frames / max(1, out.shape[0])))
353
- out = np.tile(out, (reps, 1))[-ctx_frames:]
354
-
355
- # 8) Depth guard
356
- if out.shape[1] != depth:
357
- out = out[:, :depth]
358
- return out
359
-
360
-
361
- def _realign_emit_pointer_to_bar(self, sr_model: int):
362
- """Advance _next_emit_start to the next bar boundary in model-sample space."""
363
- bar_samps = int(round(self._seconds_per_bar() * sr_model))
364
- if bar_samps <= 0:
365
- return
366
- phase = self._next_emit_start % bar_samps
367
- if phase != 0:
368
- self._next_emit_start += (bar_samps - phase)
369
-
370
- def _prepare_stream_for_reseed_handoff(self):
371
- # OLD: keep crossfade tail -> causes phase offset
372
- # sr = int(self.mrt.sample_rate)
373
- # xfade_s = float(self.mrt.config.crossfade_length)
374
- # xfade_n = int(round(xfade_s * sr))
375
- # if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0:
376
- # tail = self._stream[-xfade_n:] if self._stream.shape[0] > xfade_n else self._stream
377
- # self._stream = tail.copy()
378
- # else:
379
- # self._stream = None
380
-
381
- # NEW: throw away the tail completely; start fresh
382
- self._stream = None
383
-
384
- self._next_emit_start = 0
385
- self._needs_bar_realign = True
386
-
387
- def reseed_splice(self, recent_wav, anchor_bars: float):
388
- """
389
- Token-splice reseed queued for the next bar boundary between chunks.
390
- """
391
- with self._lock:
392
- if not hasattr(self, "_original_context_tokens") or self._original_context_tokens is None:
393
- self._original_context_tokens = np.copy(self.state.context_tokens)
394
-
395
- recent_tokens = self._make_recent_tokens_from_wave(recent_wav) # [T, depth]
396
- new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)
397
-
398
- # Queue it; the run loop will install right after we finish the current slice
399
- self._pending_reseed = {"ctx": new_ctx, "ref": recent_wav}
400
-
401
- # install the new context window
402
- new_state = self.mrt.init_state()
403
- new_state.context_tokens = new_ctx
404
- self.state = new_state
405
-
406
- self._prepare_stream_for_reseed_handoff()
407
-
408
- # optional: ask streamer to drop an intro crossfade worth of audio right after reseed
409
- self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
410
 
411
  def run(self):
412
- """Main worker loop generate into a continuous stream, then emit bar-aligned slices."""
413
- spb = self._seconds_per_bar() # seconds per bar
414
- chunk_secs = self.params.bars_per_chunk * spb
415
- xfade = float(self.mrt.config.crossfade_length) # seconds
416
- sr = int(self.mrt.sample_rate)
417
- chunk_samps = int(round(chunk_secs * sr))
418
-
419
- def _need(first_chunk_extra=False):
420
- """How many more samples we still need in the stream to emit next slice."""
421
- have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
422
- want = chunk_samps
423
- if first_chunk_extra:
424
- # reserve two bars extra so first-chunk onset alignment has material
425
- want += int(round(2 * spb * sr))
426
- return max(0, want - have)
427
-
428
- def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
429
- if x.ndim == 2: x = x.mean(axis=1)
430
- x = np.abs(x).astype(np.float32)
431
- w = max(1, int(round(win_ms * 1e-3 * sr)))
432
- if w > 1:
433
- kern = np.ones(w, dtype=np.float32) / float(w)
434
- x = np.convolve(x, kern, mode="same")
435
- d = np.diff(x, prepend=x[:1])
436
- d[d < 0] = 0.0
437
- return d
438
-
439
- def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
440
- """Tempo-aware first-downbeat offset (positive => model late)."""
441
- try:
442
- max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
443
- ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
444
- n_bar = int(round(spb * sr))
445
- ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
446
- gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
447
- if ref_tail.size == 0 or gen_head.size == 0:
448
- return 0
449
-
450
- # envelopes + z-score
451
- import numpy as np
452
- def _z(a):
453
- m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
454
- e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
455
- e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
456
-
457
- # upsample x4 for finer lag
458
- def _upsample(a, r=4):
459
- n = len(a); grid = np.arange(n, dtype=np.float32)
460
- fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
461
- return np.interp(fine, grid, a).astype(np.float32)
462
- up = 4
463
- e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
464
-
465
- max_lag_u = int(round((max_ms / 1000.0) * sr * up))
466
- seg = min(len(e_ref_u), len(e_gen_u))
467
- e_ref_u = e_ref_u[-seg:]
468
- pad = np.zeros(max_lag_u, dtype=np.float32)
469
- e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
470
-
471
- best_lag_u, best_score = 0, -1e9
472
- for lag_u in range(-max_lag_u, max_lag_u + 1):
473
- start = max_lag_u + lag_u
474
- b = e_gen_u_pad[start : start + seg]
475
- denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
476
- score = float(np.dot(e_ref_u, b) / denom)
477
- if score > best_score:
478
- best_score, best_lag_u = score, lag_u
479
- return int(round(best_lag_u / up))
480
- except Exception:
481
- return 0
482
-
483
- print("🚀 JamWorker started (bar-aligned streaming)…")
484
-
485
- while not self._stop_event.is_set():
486
- if not self._should_generate_next_chunk():
487
- time.sleep(0.25)
488
- continue
489
 
490
- # 1) Generate until we have enough material in the stream
491
- need = _need(first_chunk_extra=(self.idx == 0))
492
- while need > 0 and not self._stop_event.is_set():
493
- with self._lock:
494
- style_vec = self.params.style_vec
495
- self.mrt.guidance_weight = float(self.params.guidance_weight)
496
- self.mrt.temperature = float(self.params.temperature)
497
- self.mrt.topk = int(self.params.topk)
498
- wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
499
- self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
500
- need = _need(first_chunk_extra=(self.idx == 0))
501
-
502
- if self._stop_event.is_set():
503
- break
504
-
505
- # 2) One-time: align the emit pointer to the groove
506
- if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
507
- ref_loop = self._reseed_ref_loop or self.params.combined_loop
508
- if ref_loop is not None:
509
- head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
510
- seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
511
- gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
512
- offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
513
- if offs != 0:
514
- self._next_emit_start = max(0, self._next_emit_start + offs)
515
- print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
516
- self._realign_emit_pointer_to_bar(sr)
517
- self._needs_bar_realign = False
518
- self._reseed_ref_loop = None
519
-
520
- # 3) Emit exactly bars_per_chunk × spb from the stream
521
- start = self._next_emit_start
522
- end = start + chunk_samps
523
- if end > self._stream.shape[0]:
524
- # shouldn't happen often; generate a bit more and loop
525
  continue
526
 
527
- slice_ = self._stream[start:end]
528
- self._next_emit_start = end
529
-
530
- y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
531
-
532
- # 4) Post-processing / loudness
533
- if self.idx == 0 and self.params.ref_loop is not None:
534
- y, _ = match_loudness_to_reference(
535
- self.params.ref_loop, y,
536
- method=self.params.loudness_mode,
537
- headroom_db=self.params.headroom_db
538
- )
539
- else:
540
- apply_micro_fades(y, 3)
541
-
542
- # 5) Resample + exact-length snap + encode
543
- b64, meta = self._snap_and_encode(
544
- y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
545
- )
546
- meta["xfade_seconds"] = xfade
547
-
548
- # 6) Publish
549
- with self._lock:
550
- self.idx += 1
551
- self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
552
- if len(self.outbox) > 10:
553
- cutoff = self._last_delivered_index - 5
554
- self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
555
-
556
- # 👉 If a reseed was requested, apply it *now*, between chunks
557
- if self._pending_reseed is not None:
558
- pkg = self._pending_reseed
559
- self._pending_reseed = None
560
-
561
- new_state = self.mrt.init_state()
562
- new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
563
- self.state = new_state
564
-
565
- # start a fresh stream and schedule one-time alignment
566
- self._stream = None
567
- self._next_emit_start = 0
568
- self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
569
- self._needs_bar_realign = True
570
-
571
- print("🔁 Reseed installed at bar boundary; will realign before next slice")
572
-
573
- print(f"✅ Completed chunk {self.idx}")
574
-
575
- print("🛑 JamWorker stopped")
576
-
 
1
+ # jam_worker.py - Bar-locked spool rewrite
2
+ from __future__ import annotations
3
+
4
+ import threading, time
5
+ from dataclasses import dataclass
6
+ from fractions import Fraction
7
+ from typing import Optional, Dict, Tuple, List
8
+
9
  import numpy as np
 
10
  from magenta_rt import audio as au
11
+
12
  from utils import (
13
+ StreamingResampler,
14
+ match_loudness_to_reference,
15
+ make_bar_aligned_context,
16
+ take_bar_aligned_tail,
17
+ wav_bytes_base64,
18
  )
19
 
20
+ # -----------------------------
21
+ # Data classes
22
+ # -----------------------------
23
+
24
  @dataclass
25
  class JamParams:
26
  bpm: float
 
29
  target_sr: int
30
  loudness_mode: str = "auto"
31
  headroom_db: float = 1.0
32
+ style_vec: Optional[np.ndarray] = None
33
+ ref_loop: Optional[au.Waveform] = None
34
+ combined_loop: Optional[au.Waveform] = None
35
  guidance_weight: float = 1.1
36
  temperature: float = 1.1
37
  topk: int = 40
38
 
39
+
40
  @dataclass
41
  class JamChunk:
42
  index: int
43
  audio_base64: str
44
  metadata: dict
45
 
 
 
 
 
 
 
46
 
47
+ # -----------------------------
48
+ # Helpers
49
+ # -----------------------------
50
 
51
class BarClock:
    """Sample-domain bar clock with drift-free absolute boundaries.

    The bar length is stored as an exact ``Fraction`` of samples and every
    boundary is computed from the origin (``base``), so rounding to whole
    samples happens once per boundary and never accumulates.

    Args:
        target_sr: Sample rate (Hz) in which boundaries are expressed.
        bpm: Tempo in beats per minute. Parsed through ``str`` so decimal
            tempos such as ``120.5`` are represented exactly.
        beats_per_bar: Number of beats in one bar.
        base_offset_samples: Sample index at which bar 0 starts.

    Raises:
        ValueError: If ``target_sr``, ``bpm`` or ``beats_per_bar`` is not
            strictly positive, or ``bpm`` is not a finite number.
    """

    def __init__(self, target_sr: int, bpm: float, beats_per_bar: int, base_offset_samples: int = 0):
        self.sr = int(target_sr)
        self.bpm = Fraction(str(bpm))  # exact decimal to avoid FP drift
        self.beats_per_bar = int(beats_per_bar)

        # Reject values that would otherwise surface as ZeroDivisionError or
        # as silently reversed / negative chunk bounds.
        if self.sr <= 0:
            raise ValueError(f"target_sr must be positive, got {target_sr!r}")
        if self.bpm <= 0:
            raise ValueError(f"bpm must be positive, got {bpm!r}")
        if self.beats_per_bar <= 0:
            raise ValueError(f"beats_per_bar must be positive, got {beats_per_bar!r}")

        # Exact (possibly fractional) number of samples in one bar.
        self.bar_samps = Fraction(self.sr * 60 * self.beats_per_bar, 1) / self.bpm
        self.base = int(base_offset_samples)

    def bounds_for_chunk(self, chunk_index: int, bars_per_chunk: int) -> Tuple[int, int]:
        """Return ``(start, end)`` sample indices of chunk ``chunk_index``.

        Both ends are rounded independently from the exact position, so the
        end of chunk ``i`` always equals the start of chunk ``i + 1``.
        """
        first_bar = chunk_index * bars_per_chunk
        last_bar = (chunk_index + 1) * bars_per_chunk
        start_f = self.base + self.bar_samps * first_bar
        end_f = self.base + self.bar_samps * last_bar
        return int(round(start_f)), int(round(end_f))

    def seconds_per_bar(self) -> float:
        """Return the duration of one bar in seconds."""
        return float(self.beats_per_bar) * (60.0 / float(self.bpm))
67
 
 
 
 
 
 
 
 
68
 
69
+ # -----------------------------
70
+ # Worker
71
+ # -----------------------------
72
 
73
+ class JamWorker(threading.Thread):
74
+ """Generates continuous audio with MagentaRT, spools it at target SR,
75
+ and emits *sample-accurate*, bar-aligned chunks (no FPS drift)."""
76
 
77
def __init__(self, mrt, params: JamParams):
    """Set up generation state, the target-rate spool and delivery bookkeeping.

    Args:
        mrt: Model wrapper; this constructor reads ``init_state()``,
            ``codec.frame_rate``, ``config.context_length_frames`` and
            ``sample_rate`` from it and writes the three sampling knobs.
        params: Jam parameters (tempo, metre, target rate, knobs, loops).
    """
    super().__init__(daemon=True)
    self.mrt = mrt
    self.params = params

    # generation state
    self.state = self.mrt.init_state()
    self.mrt.guidance_weight = float(self.params.guidance_weight)
    self.mrt.temperature = float(self.params.temperature)
    self.mrt.topk = int(self.params.topk)

    # style vector (already normalized upstream)
    self._style_vec = self.params.style_vec

    # codec/setup
    self._codec_fps = float(self.mrt.codec.frame_rate)
    self._ctx_frames = int(self.mrt.config.context_length_frames)
    self._ctx_seconds = self._ctx_frames / self._codec_fps

    # model stream (model SR) for internal continuity/crossfades
    self._model_stream: Optional[np.ndarray] = None
    self._model_sr = int(self.mrt.sample_rate)

    # target-SR in-RAM spool (what we cut loops from)
    self._rs = StreamingResampler(self._model_sr, int(self.params.target_sr), channels=2)
    self._spool = np.zeros((0, 2), dtype=np.float32)  # (S,2) target SR
    self._spool_written = 0  # absolute frames written into spool

    # bar clock: start with offset 0; if you have a downbeat estimator, set base later
    self._bar_clock = BarClock(
        self.params.target_sr,
        self.params.bpm,
        self.params.beats_per_bar,
        base_offset_samples=0,
    )

    # emission counters
    self.idx = 0  # next chunk index to *produce*
    self._next_to_deliver = 0  # next chunk index to hand out via get_next_chunk()
    self._last_consumed_index = -1  # updated via mark_chunk_consumed(); generation throttle uses this

    # outbox and synchronization
    self._outbox: Dict[int, JamChunk] = {}
    self._cv = threading.Condition()

    # control flags
    # NOTE(review): on CPython releases where threading.Thread defines a
    # private _stop() method, this attribute shadows it and join()/is_alive()
    # can fail with "'Event' object is not callable" once the thread has
    # finished. Renaming needs a coordinated change in every reader of
    # self._stop, so it is only flagged here.
    self._stop = threading.Event()
    self._max_buffer_ahead = 5

    # reseed queue (install at next safe point)
    self._pending_reseed: Optional[dict] = None

    # Always define the attribute so it exists even when no combined loop is
    # supplied; _install_context_from_loop() overwrites it with real tokens.
    self._original_context_tokens: Optional[np.ndarray] = None

    # Prepare initial context from combined loop (best musical alignment)
    if self.params.combined_loop is not None:
        self._install_context_from_loop(self.params.combined_loop)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ # ---------- lifecycle ----------
 
 
 
 
129
 
130
def stop(self):
    """Request shutdown and wake anything blocked on the condition.

    Setting the flag alone leaves a consumer parked in ``self._cv.wait``
    until its poll interval expires; notifying lets it re-check at once.
    """
    self._stop.set()
    with self._cv:
        self._cv.notify_all()
132
+
133
+ # FastAPI reads this to block until the next sequential chunk is ready
134
def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]:
    """Block until the next sequential chunk is available and return it.

    Args:
        timeout: Maximum number of seconds to wait.

    Returns:
        The chunk stored under ``self._next_to_deliver`` (the counter is then
        advanced), or ``None`` if the timeout expires or the worker has been
        stopped before that chunk was produced.
    """
    # Monotonic clock: the deadline must not move with wall-clock changes.
    deadline = time.monotonic() + timeout
    with self._cv:
        while True:
            chunk = self._outbox.get(self._next_to_deliver)
            if chunk is not None:
                self._next_to_deliver += 1
                return chunk
            # Checked after the outbox, so a chunk that is already queued is
            # still delivered; otherwise a stopped worker would keep callers
            # waiting for the whole timeout.
            if self._stop.is_set():
                return None
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            self._cv.wait(timeout=min(0.25, remaining))
146
 
147
def mark_chunk_consumed(self, chunk_index: int):
    """Record that the consumer finished ``chunk_index`` and drop stale chunks.

    The consumed index only ever moves forward. Chunks older than the one
    before the newest consumed index are purged to cap memory, but never a
    chunk that ``get_next_chunk`` has not handed out yet: removing one would
    leave the delivery cursor waiting for a key that no longer exists.
    """
    with self._cv:
        self._last_consumed_index = max(self._last_consumed_index, int(chunk_index))
        purge_below = min(self._last_consumed_index - 1, self._next_to_deliver)
        stale = [key for key in self._outbox if key < purge_below]
        for key in stale:
            self._outbox.pop(key, None)
155
 
156
def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
    """Update sampling knobs on ``self.params`` and mirror them onto ``self.mrt``.

    Arguments left as ``None`` keep their current value. All supplied values
    are converted before anything is stored, so a value that cannot be
    converted raises without leaving a partial update behind.

    Raises:
        ValueError, TypeError: If a supplied value cannot be converted to
            ``float`` (``guidance_weight``, ``temperature``) or ``int`` (``topk``).
    """
    # Convert first, assign afterwards: all-or-nothing.
    new_guidance = None if guidance_weight is None else float(guidance_weight)
    new_temperature = None if temperature is None else float(temperature)
    new_topk = None if topk is None else int(topk)

    if new_guidance is not None:
        self.params.guidance_weight = new_guidance
    if new_temperature is not None:
        self.params.temperature = new_temperature
    if new_topk is not None:
        self.params.topk = new_topk

    # push into mrt (thread-safe enough for our use)
    self.mrt.guidance_weight = float(self.params.guidance_weight)
    self.mrt.temperature = float(self.params.temperature)
    self.mrt.topk = int(self.params.topk)
167
+
168
+ # ---------- context / reseed ----------
169
+
170
def _install_context_from_loop(self, loop: au.Waveform):
    """Seed a fresh model state with context tokens encoded from ``loop``.

    The loop is converted to stereo at the model sample rate, reduced with
    ``take_bar_aligned_tail`` using ``self._ctx_seconds``, encoded with the
    codec, and truncated to the decoder's RVQ depth. A copy of the tokens is
    kept in ``self._original_context_tokens`` for future splices.
    """
    model_rate_loop = loop.as_stereo().resample(self._model_sr)
    bar_tail = take_bar_aligned_tail(
        model_rate_loop,
        self.params.bpm,
        self.params.beats_per_bar,
        self._ctx_seconds,
    )
    encoded = self.mrt.codec.encode(bar_tail).astype(np.int32)
    rvq_depth = int(self.mrt.config.decoder_codec_rvq_depth)
    # NOTE(review): tokens are installed exactly as encoded; nothing here
    # forces the frame count to equal self._ctx_frames — confirm that the
    # model state accepts a shorter window.
    context_tokens = encoded[:, :rvq_depth]

    # install state
    fresh_state = self.mrt.init_state()
    fresh_state.context_tokens = context_tokens
    self.state = fresh_state

    # keep an original copy for future splices
    self._original_context_tokens = np.copy(context_tokens)
 
 
185
 
186
def reseed_from_waveform(self, wav: au.Waveform):
    """Immediate reseed: replace context from provided wave (bar-aligned tail).

    Builds and installs the new context through
    ``_install_context_from_loop`` (stereo at model rate, bar-aligned tail,
    codec encode, truncate to decoder RVQ depth), which also remembers the
    tokens as the new "original". The model stream is then cleared so the
    next generated chunk starts cleanly.
    """
    # NOTE(review): no lock is taken here; confirm the generation thread
    # cannot overwrite self.state while this runs.
    self._install_context_from_loop(wav)

    # reset model stream so next generate starts cleanly
    self._model_stream = None

    # optional loudness match will be applied per-chunk on emission
204
+
205
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
    """Queue a reseed built from the bar-aligned tail of ``recent_wav``.

    The new context is stored in ``self._pending_reseed`` under the ``"ctx"``
    key rather than installed here. ``anchor_bars`` is accepted for API
    compatibility and is not used by this implementation.
    """
    model_rate_wav = recent_wav.as_stereo().resample(self._model_sr)
    bar_tail = take_bar_aligned_tail(
        model_rate_wav,
        self.params.bpm,
        self.params.beats_per_bar,
        self._ctx_seconds,
    )
    encoded = self.mrt.codec.encode(bar_tail).astype(np.int32)
    rvq_depth = int(self.mrt.config.decoder_codec_rvq_depth)
    pending = {"ctx": encoded[:, :rvq_depth]}
    self._pending_reseed = pending
215
+
216
+ # ---------- core streaming helpers ----------
217
+
218
def _append_model_chunk_and_spool(self, wav: au.Waveform):
    """Append one generated chunk to the model-rate stream and spool it.

    The head of ``wav`` is crossfaded (equal-power sin/cos ramp) into the
    tail of ``self._model_stream``. Only the part of the chunk after the
    crossfade region is pushed through ``self._rs`` and appended to
    ``self._spool``, with ``self._spool_written`` advanced by the number of
    resampled frames.
    """
    # NOTE(review): self._model_stream grows with every chunk, although this
    # method only ever reads its last crossfade-length frames.
    chunk = wav.samples.astype(np.float32, copy=False)
    if chunk.ndim == 1:
        chunk = chunk[:, None]  # mono -> column vector

    model_sr = self._model_sr
    fade_seconds = float(self.mrt.config.crossfade_length)
    fade_len = int(round(max(0.0, fade_seconds) * model_sr))

    stream = self._model_stream
    if stream is None:
        # First chunk after (re)start: discard the crossfade preroll, then
        # spool whatever is left.
        fresh = chunk[fade_len:] if fade_len < chunk.shape[0] else chunk[:0]
        self._model_stream = fresh.copy()
        if fresh.size:
            resampled = self._rs.process(fresh, final=False)
            self._spool = np.concatenate([self._spool, resampled], axis=0)
            self._spool_written += resampled.shape[0]
        return

    can_crossfade = (
        fade_len > 0
        and stream.shape[0] >= fade_len
        and chunk.shape[0] >= fade_len
    )
    if can_crossfade:
        outgoing = stream[-fade_len:]
        incoming = chunk[:fade_len]
        ramp = np.linspace(0, np.pi/2, fade_len, endpoint=False, dtype=np.float32)[:, None]
        blended = outgoing * np.cos(ramp) + incoming * np.sin(ramp)
        self._model_stream = np.concatenate(
            [stream[:-fade_len], blended, chunk[fade_len:]], axis=0
        )
        fresh = chunk[fade_len:]
    else:
        # Not enough audio on one side to crossfade: plain append.
        self._model_stream = np.concatenate([stream, chunk], axis=0)
        fresh = chunk

    # Only the new, non-overlapped audio goes to the resampler.
    if not fresh.size:
        return
    resampled = self._rs.process(fresh.astype(np.float32, copy=False), final=False)
    if resampled.size:
        self._spool = np.concatenate([self._spool, resampled], axis=0)
        self._spool_written += resampled.shape[0]
256
 
257
def _should_generate_next_chunk(self) -> bool:
    """Return True while the next chunk index is within the allowed lead.

    Generation may continue as long as ``self.idx`` does not exceed the
    last consumed index plus ``self._max_buffer_ahead``.
    """
    furthest_allowed = self._last_consumed_index + self._max_buffer_ahead
    return self.idx <= furthest_allowed
260
+
261
def _emit_ready(self):
    """Emit every chunk whose samples are already present in the spool.

    For each ready chunk: slice it out of ``self._spool`` using the bounds
    from ``self._bar_clock``, optionally loudness-match it to
    ``params.ref_loop``, encode it with ``wav_bytes_base64``, store it in
    ``self._outbox`` under its index and notify waiters on ``self._cv``.
    A queued reseed, if any, is installed right after a chunk is emitted.
    """
    while True:
        cfg = self.params
        start, end = self._bar_clock.bounds_for_chunk(self.idx, cfg.bars_per_chunk)
        if end > self._spool_written:
            break  # the spool does not hold this chunk yet
        segment = self._spool[start:end]

        # Optional loudness match against the reference loop.
        wants_match = cfg.ref_loop is not None and cfg.loudness_mode != "none"
        if wants_match:
            reference = cfg.ref_loop.as_stereo().resample(cfg.target_sr)
            candidate = au.Waveform(segment.copy(), int(cfg.target_sr))
            matched, _ = match_loudness_to_reference(
                reference,
                candidate,
                method=cfg.loudness_mode,
                headroom_db=cfg.headroom_db,
            )
            segment = matched.samples

        audio_b64, total_samples, channels = wav_bytes_base64(segment, int(cfg.target_sr))
        metadata = {
            "bpm": float(cfg.bpm),
            "bars": int(cfg.bars_per_chunk),
            "beats_per_bar": int(cfg.beats_per_bar),
            "sample_rate": int(cfg.target_sr),
            "channels": int(channels),
            "total_samples": int(total_samples),
            "seconds_per_bar": self._bar_clock.seconds_per_bar(),
            "loop_duration_seconds": cfg.bars_per_chunk * self._bar_clock.seconds_per_bar(),
            "guidance_weight": float(cfg.guidance_weight),
            "temperature": float(cfg.temperature),
            "topk": int(cfg.topk),
        }
        ready = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=metadata)

        # Publish under the condition variable so waiting consumers wake up.
        with self._cv:
            self._outbox[self.idx] = ready
            self._cv.notify_all()
        self.idx += 1

        # Install a queued reseed only now, i.e. on a chunk boundary.
        queued = self._pending_reseed
        if queued is not None:
            reseeded_state = self.mrt.init_state()
            reseeded_state.context_tokens = queued["ctx"]
            self.state = reseeded_state
            # Drop model-rate continuity so the next chunk starts clean.
            self._model_stream = None
            self._pending_reseed = None
304
+
305
+ # ---------- main loop ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  def run(self):
308
+ # set style vector if present
309
+ style_vec = self._style_vec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
+ # generate until stopped
312
+ while not self._stop.is_set():
313
+ # throttle generation if we are far ahead
314
+ if not self._should_generate_next_chunk():
315
+ # still try to emit if spool already has enough
316
+ self._emit_ready()
317
+ time.sleep(0.01)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  continue
319
 
320
+ # generate next model chunk
321
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
322
+ # append and spool
323
+ self._append_model_chunk_and_spool(wav)
324
+ # try emitting zero or more chunks if available
325
+ self._emit_ready()
326
+
327
+ # finalize resampler (flush) — not strictly necessary here
328
+ tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
329
+ if tail.size:
330
+ self._spool = np.concatenate([self._spool, tail], axis=0)
331
+ self._spool_written += tail.shape[0]
332
+ # one last emit attempt
333
+ self._emit_ready()