thecollabagepatch commited on
Commit
1b98b73
Β·
1 Parent(s): 783cbeb

fixing continuity

Browse files
Files changed (2) hide show
  1. jam_worker.py +106 -72
  2. utils.py +4 -2
jam_worker.py CHANGED
@@ -350,88 +350,122 @@ class JamWorker(threading.Thread):
350
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
351
 
352
  def run(self):
353
- """Continuous stream + sliding 8-bar window emitter."""
354
- sr_model = int(self.mrt.sample_rate)
355
  spb = self._seconds_per_bar()
356
- chunk_secs = float(self.params.bars_per_chunk) * spb
357
- chunk_n_model = int(round(chunk_secs * sr_model))
358
- xfade = self.mrt.config.crossfade_length
359
-
360
- # Streaming state
361
- self._stream = None # np.ndarray [S, C] at model SR
362
- self._next_emit_start = 0 # sample pointer for next 8-bar cut
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
- print("πŸš€ JamWorker (streaming) started...")
365
 
366
  while not self._stop_event.is_set():
367
- # Flow control: don't get too far ahead of the consumer
 
 
 
 
 
 
 
368
  with self._lock:
369
- if self.idx > self._last_delivered_index + self._max_buffer_ahead:
370
- time.sleep(0.25)
371
- continue
372
  style_vec = self.params.style_vec
373
- self.mrt.guidance_weight = self.params.guidance_weight
374
- self.mrt.temperature = self.params.temperature
375
- self.mrt.topk = self.params.topk
 
376
 
377
- # Generate ONE model chunk and append to the continuous stream
378
  self.last_chunk_started_at = time.time()
379
- wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
380
- self._append_model_chunk_to_stream(wav)
381
- if getattr(self, "_needs_bar_realign", False):
382
- self._realign_emit_pointer_to_bar(sr_model)
383
- self._needs_bar_realign = False
384
- # DEBUG
385
- bar_samps = int(round(self._seconds_per_bar() * sr_model))
386
- if bar_samps > 0 and (self._next_emit_start % bar_samps) != 0:
387
- print(f"⚠️ emit pointer not aligned: phase={self._next_emit_start % bar_samps}")
388
- else:
389
- print("βœ… emit pointer aligned to bar")
390
-
391
- self.last_chunk_completed_at = time.time()
392
-
393
- # While we have at least one full 8-bar window available, emit it
394
- while (getattr(self, "_stream", None) is not None and
395
- self._stream.shape[0] - self._next_emit_start >= chunk_n_model and
396
- not self._stop_event.is_set()):
397
-
398
- seg = self._stream[self._next_emit_start:self._next_emit_start + chunk_n_model]
399
 
400
- # Wrap as Waveform at model SR
401
- y = au.Waveform(seg.astype(np.float32, copy=False), sr_model).as_stereo()
402
-
403
- # Post-processing:
404
- # - First emitted chunk: loudness-match to ref_loop
405
- # - No micro-fades on mid-stream windows (they cause dips)
406
- next_idx = self.idx + 1
407
- if next_idx == 1 and self.params.ref_loop is not None:
408
- y, _ = match_loudness_to_reference(
409
- self.params.ref_loop, y,
410
- method=self.params.loudness_mode,
411
- headroom_db=self.params.headroom_db
412
- )
413
-
414
- # Resample + snap + encode exactly chunk_secs long
415
- b64, meta = self._snap_and_encode(
416
- y, seconds=chunk_secs,
417
- target_sr=self.params.target_sr,
418
- bars=self.params.bars_per_chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
- with self._lock:
422
- self.idx = next_idx
423
- self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
424
- # Bound the outbox
425
- if len(self.outbox) > 10:
426
- self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
427
-
428
- # Advance window pointer to the next 8-bar slot
429
- self._next_emit_start += chunk_n_model
430
 
431
- # Trim old samples to keep memory bounded (keep a little guard)
432
- keep_from = max(0, self._next_emit_start - chunk_n_model) # keep 1 extra window
433
- if keep_from > 0:
434
- self._stream = self._stream[keep_from:]
435
- self._next_emit_start -= keep_from
436
 
437
- print("πŸ›‘ JamWorker (streaming) stopped")
 
350
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
351
 
352
  def run(self):
353
+ """Main worker loop - generate chunks continuously but don't get too far ahead"""
 
354
  spb = self._seconds_per_bar()
355
+ chunk_secs = self.params.bars_per_chunk * spb
356
+ xfade = float(self.mrt.config.crossfade_length) # seconds
357
+
358
+ # local fallback stitcher that *keeps* the first head if utils.stitch_generated
359
+ # doesn't yet support drop_first_pre_roll
360
+ def _stitch_keep_head(chunks, sr: int, xfade_s: float):
361
+ from magenta_rt import audio as au
362
+ import numpy as _np
363
+ if not chunks:
364
+ raise ValueError("no chunks to stitch")
365
+ xfade_n = int(round(max(0.0, xfade_s) * sr))
366
+ # Fast-path: no crossfade
367
+ if xfade_n <= 0:
368
+ out = _np.concatenate([c.samples for c in chunks], axis=0)
369
+ return au.Waveform(out, sr)
370
+ # build equal-power curves
371
+ t = _np.linspace(0, _np.pi / 2, xfade_n, endpoint=False, dtype=_np.float32)
372
+ eq_in, eq_out = _np.sin(t)[:, None], _np.cos(t)[:, None]
373
+
374
+ first = chunks[0].samples
375
+ if first.shape[0] < xfade_n:
376
+ raise ValueError("chunk shorter than crossfade prefix")
377
+ out = first.copy() # πŸ‘ˆ keep the head for live seam
378
+
379
+ for i in range(1, len(chunks)):
380
+ cur = chunks[i].samples
381
+ if cur.shape[0] < xfade_n:
382
+ # too short to crossfade; just butt-join
383
+ out = _np.concatenate([out, cur], axis=0)
384
+ continue
385
+ head, tail = cur[:xfade_n], cur[xfade_n:]
386
+ mixed = out[-xfade_n:] * eq_out + head * eq_in
387
+ out = _np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
388
+ return au.Waveform(out, sr)
389
 
390
+ print("πŸš€ JamWorker started with flow control...")
391
 
392
  while not self._stop_event.is_set():
393
+ # Don’t get too far ahead of the consumer
394
+ if not self._should_generate_next_chunk():
395
+ # We're ahead enough, wait a bit for frontend to catch up
396
+ # (kept short so stop() stays responsive)
397
+ time.sleep(0.5)
398
+ continue
399
+
400
+ # Snapshot knobs + compute index atomically
401
  with self._lock:
 
 
 
402
  style_vec = self.params.style_vec
403
+ self.mrt.guidance_weight = float(self.params.guidance_weight)
404
+ self.mrt.temperature = float(self.params.temperature)
405
+ self.mrt.topk = int(self.params.topk)
406
+ next_idx = self.idx + 1
407
 
408
+ print(f"🎹 Generating chunk {next_idx}...")
409
  self.last_chunk_started_at = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
+ # ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
412
+ # Count the first chunk at full length L, and each subsequent at (L - xfade)
413
+ assembled = 0.0
414
+ chunks = []
415
+
416
+ while assembled < chunk_secs and not self._stop_event.is_set():
417
+ # generate_chunk returns (au.Waveform, new_state)
418
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
419
+ chunks.append(wav)
420
+ L = wav.samples.shape[0] / float(self.mrt.sample_rate)
421
+ assembled += L if len(chunks) == 1 else max(0.0, L - xfade)
422
+
423
+ if self._stop_event.is_set():
424
+ break
425
+
426
+ # ---- Stitch and trim at model SR (keep first head for seamless handoff) ----
427
+ try:
428
+ # Preferred path if you've added the new param in utils.stitch_generated
429
+ y = stitch_generated(chunks, self.mrt.sample_rate, xfade, drop_first_pre_roll=False).as_stereo()
430
+ except TypeError:
431
+ # Backward-compatible: local stitcher that keeps the head
432
+ y = _stitch_keep_head(chunks, int(self.mrt.sample_rate), xfade).as_stereo()
433
+
434
+ # Hard trim to the exact musical duration (still at model SR)
435
+ y = hard_trim_seconds(y, chunk_secs)
436
+
437
+ # ---- Post-processing ----
438
+ if next_idx == 1 and self.params.ref_loop is not None:
439
+ # match loudness to the provided reference on the very first audible chunk
440
+ y, _ = match_loudness_to_reference(
441
+ self.params.ref_loop, y,
442
+ method=self.params.loudness_mode,
443
+ headroom_db=self.params.headroom_db
444
  )
445
+ else:
446
+ # light micro-fades to guard against clicks
447
+ apply_micro_fades(y, 3)
448
+
449
+ # ---- Resample + bar-snap + encode ----
450
+ b64, meta = self._snap_and_encode(
451
+ y,
452
+ seconds=chunk_secs,
453
+ target_sr=self.params.target_sr,
454
+ bars=self.params.bars_per_chunk
455
+ )
456
+ # small hint for the client if you want UI butter between chunks
457
+ meta["xfade_seconds"] = xfade
458
 
459
+ # ---- Publish the completed chunk ----
460
+ with self._lock:
461
+ self.idx = next_idx
462
+ self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
463
+ # Keep outbox bounded (trim far-behind entries)
464
+ if len(self.outbox) > 10:
465
+ cutoff = self._last_delivered_index - 5
466
+ self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
 
467
 
468
+ self.last_chunk_completed_at = time.time()
469
+ print(f"βœ… Completed chunk {next_idx}")
 
 
 
470
 
471
+ print("πŸ›‘ JamWorker stopped")
utils.py CHANGED
@@ -69,7 +69,7 @@ def match_loudness_to_reference(
69
 
70
 
71
  # ---------- Stitch / fades / trims ----------
72
- def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform:
73
  if not chunks:
74
  raise ValueError("no chunks")
75
  xfade_n = int(round(xfade_s * sr))
@@ -82,7 +82,9 @@ def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform:
82
  first = chunks[0].samples
83
  if first.shape[0] < xfade_n:
84
  raise ValueError("chunk shorter than crossfade prefix")
85
- out = first[xfade_n:].copy() # drop model pre-roll
 
 
86
 
87
  for i in range(1, len(chunks)):
88
  cur = chunks[i].samples
 
69
 
70
 
71
  # ---------- Stitch / fades / trims ----------
72
+ def stitch_generated(chunks, sr: int, xfade_s: float, drop_first_pre_roll: bool = True):
73
  if not chunks:
74
  raise ValueError("no chunks")
75
  xfade_n = int(round(xfade_s * sr))
 
82
  first = chunks[0].samples
83
  if first.shape[0] < xfade_n:
84
  raise ValueError("chunk shorter than crossfade prefix")
85
+
86
+ # πŸ”§ key change:
87
+ out = first[xfade_n:].copy() if drop_first_pre_roll else first.copy()
88
 
89
  for i in range(1, len(chunks)):
90
  cur = chunks[i].samples