thecollabagepatch committed on
Commit
1f74f2f
·
1 Parent(s): dec57f5
Files changed (1) hide show
  1. jam_worker.py +142 -152
jam_worker.py CHANGED
@@ -410,196 +410,186 @@ class JamWorker(threading.Thread):
410
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
411
 
412
  def run(self):
413
- """Main worker loop — continuous gen at model SR, stream resampled chunks butt-joined at target SR."""
414
- import numpy as _np
415
- from math import floor, ceil
416
- spb = self._seconds_per_bar() # seconds per bar
417
- chunk_secs = float(self.params.bars_per_chunk) * spb # seconds per client chunk
418
- xfade_s = float(self.mrt.config.crossfade_length) # seconds of model equal-power xfade
419
- sr_in = int(self.mrt.sample_rate) # model/native SR
420
- sr_out = int(self.params.target_sr or sr_in) # desired output SR (e.g., 44100)
421
- ch = 2 # enforce stereo out
422
-
423
- # --- Fractional emit steppers (input + output domains) ---
424
- chunk_step_in_f = chunk_secs * sr_in
425
- self._emit_phase = float(getattr(self, "_emit_phase", 0.0))
426
-
427
- chunk_step_out_f = chunk_secs * sr_out
428
- self._emit_phase_out = float(getattr(self, "_emit_phase_out", 0.0))
429
- self._next_emit_start_out = int(getattr(self, "_next_emit_start_out", 0))
430
-
431
- # --- Streaming resampler state (input -> output); hold back xfade tail so overlapped region is final ---
432
- xfade_n_in = int(round(xfade_s * sr_in))
433
- self._resampler = None
434
- self._stream_out = None
435
- self._resample_cursor_in = int(getattr(self, "_resample_cursor_in", 0)) # how many INPUT samples we fed to the resampler
436
-
437
- if sr_out != sr_in:
438
- try:
439
- from utils import StreamingResampler
440
- self._resampler = StreamingResampler(in_sr=sr_in, out_sr=sr_out, channels=ch, quality="VHQ")
441
- self._stream_out = _np.zeros((0, ch), dtype=_np.float32)
442
- except Exception as e:
443
- print(f"⚠️ Could not init StreamingResampler ({e}); falling back to alias-mode (sr_out==sr_in).")
444
- sr_out = sr_in
445
- self.params.target_sr = sr_out
446
- self._resampler = None
447
- self._stream_out = _np.zeros((0, ch), dtype=_np.float32)
448
- self._resample_cursor_in = 0
449
- else:
450
- self._stream_out = _np.zeros((0, ch), dtype=_np.float32)
451
- self._resample_cursor_in = 0
452
-
453
- # --- helper: how many more INPUT samples (stable) we need to be able to emit next client chunk ---
454
- def _need(first_chunk_extra: bool=False) -> int:
455
- start = int(getattr(self, "_next_emit_start", 0))
456
- total_in = 0 if getattr(self, "_stream", None) is None else int(self._stream.shape[0])
457
- total_in_stable = max(0, total_in - xfade_n_in) # hold back xfade tail (overlap will be replaced)
458
- have = max(0, total_in_stable - start)
459
  emit_phase = float(getattr(self, "_emit_phase", 0.0))
460
- step_int_in = int(floor(chunk_step_in_f + emit_phase))
461
- want = step_int_in
 
 
462
  if first_chunk_extra:
463
- want += int(ceil(2.0 * spb * sr_in))
 
 
 
464
  return max(0, want - have)
465
 
466
- print(f"▢️ JamWorker starting: bpm={self.params.bpm}, bars/chunk={self.params.bars_per_chunk}, "
467
- f"sr_in={sr_in}, sr_out={sr_out}, xfade_s={xfade_s:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
- # --- main loop ---
470
  while not self._stop_event.is_set():
471
- # 0) Backpressure: don't run too far ahead
472
  if not self._should_generate_next_chunk():
473
- time.sleep(0.01)
474
  continue
475
 
476
- # 1) Ensure enough model audio exists (INPUT domain)
477
  need = _need(first_chunk_extra=(self.idx == 0))
478
- if need > 0:
479
- # Generate one model chunk
480
- style_vec = self.params.style_vec
481
- self.mrt.guidance_weight = float(self.params.guidance_weight)
482
- self.mrt.temperature = float(self.params.temperature)
483
- self.mrt.topk = int(self.params.topk)
484
-
485
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
486
-
487
- # Append (equal-power crossfade into persistent input-domain stream)
488
- self._append_model_chunk_to_stream(wav)
489
-
490
- # Feed *stable* portion into the resampler/output buffer
491
- if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0:
492
- stable_end_in = max(0, int(self._stream.shape[0]) - xfade_n_in)
493
- if stable_end_in > self._resample_cursor_in:
494
- x_in = self._stream[self._resample_cursor_in:stable_end_in]
495
- if self._resampler is not None:
496
- y_out = self._resampler.process(x_in.astype(_np.float32, copy=False), final=False)
497
- if y_out.size:
498
- self._stream_out = y_out if self._stream_out.size == 0 else _np.vstack([self._stream_out, y_out])
499
- else:
500
- # pass-through (sr_out == sr_in)
501
- self._stream_out = x_in if self._stream_out.size == 0 else _np.vstack([self._stream_out, x_in])
502
- self._resample_cursor_in = stable_end_in
503
- # loop back to either generate more or try emitting
504
- continue
505
-
506
- # 2) Optional, one-shot bar realign (occurs on first slice or reseed)
507
- if getattr(self, "_needs_bar_realign", False):
508
- self._realign_emit_pointer_to_bar(sr_in)
509
- self._emit_phase = 0.0 # restart fractional phase at clean bar
510
  self._needs_bar_realign = False
511
  self._reseed_ref_loop = None
512
 
513
- # 3) Compute next emit window in BOTH domains
514
- start_in = int(getattr(self, "_next_emit_start", 0))
515
- step_total_in = chunk_step_in_f + self._emit_phase
516
- step_int_in = int(floor(step_total_in))
517
- new_phase_in = float(step_total_in - step_int_in)
518
- end_in = start_in + step_int_in
519
-
520
- start_out = int(self._next_emit_start_out)
521
- step_total_out = chunk_step_out_f + self._emit_phase_out
522
- step_int_out = int(floor(step_total_out))
523
- new_phase_out = float(step_total_out - step_int_out)
524
- end_out = start_out + step_int_out
525
-
526
- # 4) Guards — do we actually have enough ready in both domains?
527
- total_in_stable = 0
528
- if getattr(self, "_stream", None) is not None:
529
- total_in_stable = max(0, int(self._stream.shape[0]) - xfade_n_in)
530
- total_out_ready = 0 if self._stream_out is None else int(self._stream_out.shape[0])
531
-
532
- if end_in > total_in_stable or end_out > total_out_ready:
533
- time.sleep(0.005)
534
  continue
535
 
536
- # 5) Slice OUTPUT-domain audio to send
537
- slice_out = self._stream_out[start_out:end_out]
538
-
539
- # Advance pointers + phases atomically
540
- self._next_emit_start = end_in
541
- self._emit_phase = new_phase_in
542
- self._next_emit_start_out = end_out
543
- self._emit_phase_out = new_phase_out
544
 
545
- # 6) Post and encode
546
- y = au.Waveform(slice_out.astype(_np.float32, copy=False), sr_out).as_stereo()
547
 
548
- # Loudness: only on first chunk, match to ref if provided
549
  if self.idx == 0 and self.params.ref_loop is not None:
550
  y, _ = match_loudness_to_reference(
551
  self.params.ref_loop, y,
552
  method=self.params.loudness_mode,
553
  headroom_db=self.params.headroom_db
554
  )
555
- # (No per-slice micro fades; stream continuity handles joins)
 
556
 
557
- # Encode WAV (already sr_out and exact length by construction)
558
- b64, total_samples, channels = wav_bytes_base64(
559
- y.samples if y.samples.ndim == 2 else y.samples[:, None], sr_out
560
  )
 
561
 
562
- meta = {
563
- "bpm": int(round(self.params.bpm)),
564
- "bars": int(self.params.bars_per_chunk),
565
- "beats_per_bar": int(self.params.beats_per_bar),
566
- "sample_rate": int(sr_out),
567
- "channels": int(channels),
568
- "total_samples": int(total_samples),
569
- "seconds_per_bar": float(spb),
570
- "loop_duration_seconds": float(self.params.bars_per_chunk) * float(spb),
571
- "guidance_weight": float(self.params.guidance_weight),
572
- "temperature": float(self.params.temperature),
573
- "topk": int(self.params.topk),
574
- "xfade_seconds": float(xfade_s),
575
- }
576
-
577
  with self._lock:
578
  self.idx += 1
579
  self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
580
- # prune outbox
581
  if len(self.outbox) > 10:
582
  cutoff = self._last_delivered_index - 5
583
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
584
 
585
- # Apply any pending reseed *between* chunks
586
- if getattr(self, "_pending_reseed", None) is not None:
587
  pkg = self._pending_reseed
588
  self._pending_reseed = None
589
- # A reseed handler has already swapped state.context_tokens upstream.
590
- # Just request a one-shot bar realign against the new ref loop if present.
 
 
 
 
 
 
 
591
  self._needs_bar_realign = True
592
- self._reseed_ref_loop = pkg.get("ref") if isinstance(pkg, dict) else None
593
 
594
- time.sleep(0.001)
595
 
596
- # --- graceful stop: flush resampler tail so last bits become available if client requests them ---
597
- try:
598
- if self._resampler is not None:
599
- tail = self._resampler.flush()
600
- if tail.size:
601
- self._stream_out = tail if self._stream_out.size == 0 else _np.vstack([self._stream_out, tail])
602
- except Exception as e:
603
- print(f"⚠️ Resampler flush error: {e}")
604
 
605
  print("🛑 JamWorker stopped")
 
 
410
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
411
 
412
  def run(self):
413
+ """Main worker loop — generate into a continuous stream, then emit bar-aligned slices."""
414
+ spb = self._seconds_per_bar() # seconds per bar
415
+ chunk_secs = self.params.bars_per_chunk * spb
416
+ xfade = float(self.mrt.config.crossfade_length) # seconds
417
+ sr = int(self.mrt.sample_rate)
418
+ chunk_step_f = chunk_secs * sr # float samples per chunk
419
+ self._emit_phase = getattr(self, "_emit_phase", 0.0)
420
+
421
+ def _need(first_chunk_extra: bool = False) -> int:
422
+ """
423
+ How many more samples we still need in the stream to emit the next slice.
424
+ Uses the fractional step (chunk_step_f) + current _emit_phase to compute
425
+ the *integer* number of samples required for the next chunk, without
426
+ mutating _emit_phase here.
427
+ """
428
+ start = getattr(self, "_next_emit_start", 0)
429
+ total = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0]
430
+ have = max(0, total - start)
431
+
432
+ # Compute the integer step we'd use for the next emit, non-mutating.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  emit_phase = float(getattr(self, "_emit_phase", 0.0))
434
+ step_int = int(floor(chunk_step_f + emit_phase)) # matches the logic used when advancing
435
+
436
+ # How much we want available beyond 'start' for this emit.
437
+ want = step_int
438
  if first_chunk_extra:
439
+ # Reserve two extra bars so the first-chunk onset alignment has material.
440
+ # Use ceil to be conservative so we don't under-request.
441
+ want += int(ceil(2.0 * spb * sr))
442
+
443
  return max(0, want - have)
444
 
445
+ def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
446
+ if x.ndim == 2: x = x.mean(axis=1)
447
+ x = np.abs(x).astype(np.float32)
448
+ w = max(1, int(round(win_ms * 1e-3 * sr)))
449
+ if w > 1:
450
+ kern = np.ones(w, dtype=np.float32) / float(w)
451
+ x = np.convolve(x, kern, mode="same")
452
+ d = np.diff(x, prepend=x[:1])
453
+ d[d < 0] = 0.0
454
+ return d
455
+
456
+ def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
457
+ """Tempo-aware first-downbeat offset (positive => model late)."""
458
+ try:
459
+ max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
460
+ ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
461
+ n_bar = int(round(spb * sr))
462
+ ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
463
+ gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
464
+ if ref_tail.size == 0 or gen_head.size == 0:
465
+ return 0
466
+
467
+ # envelopes + z-score
468
+ def _z(a):
469
+ m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
470
+ e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
471
+ e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
472
+
473
+ # upsample x4 for finer lag
474
+ def _upsample(a, r=4):
475
+ n = len(a); grid = np.arange(n, dtype=np.float32)
476
+ fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
477
+ return np.interp(fine, grid, a).astype(np.float32)
478
+ up = 4
479
+ e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
480
+
481
+ max_lag_u = int(round((max_ms / 1000.0) * sr * up))
482
+ seg = min(len(e_ref_u), len(e_gen_u))
483
+ e_ref_u = e_ref_u[-seg:]
484
+ pad = np.zeros(max_lag_u, dtype=np.float32)
485
+ e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
486
+
487
+ best_lag_u, best_score = 0, -1e9
488
+ for lag_u in range(-max_lag_u, max_lag_u + 1):
489
+ start = max_lag_u + lag_u
490
+ b = e_gen_u_pad[start : start + seg]
491
+ denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
492
+ score = float(np.dot(e_ref_u, b) / denom)
493
+ if score > best_score:
494
+ best_score, best_lag_u = score, lag_u
495
+ return int(round(best_lag_u / up))
496
+ except Exception:
497
+ return 0
498
+
499
+ print("🚀 JamWorker started (bar-aligned streaming)…")
500
 
 
501
  while not self._stop_event.is_set():
 
502
  if not self._should_generate_next_chunk():
503
+ time.sleep(0.25)
504
  continue
505
 
506
+ # 1) Generate until we have enough material in the stream
507
  need = _need(first_chunk_extra=(self.idx == 0))
508
+ while need > 0 and not self._stop_event.is_set():
509
+ with self._lock:
510
+ style_vec = self.params.style_vec
511
+ self.mrt.guidance_weight = float(self.params.guidance_weight)
512
+ self.mrt.temperature = float(self.params.temperature)
513
+ self.mrt.topk = int(self.params.topk)
 
514
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
515
+ self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
516
+ need = _need(first_chunk_extra=(self.idx == 0))
517
+
518
+ if self._stop_event.is_set():
519
+ break
520
+
521
+ # 2) One-time: align the emit pointer to the groove
522
+ if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
523
+ ref_loop = self._reseed_ref_loop or self.params.combined_loop
524
+ if ref_loop is not None:
525
+ head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
526
+ seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
527
+ gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
528
+ offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
529
+ if offs != 0:
530
+ self._next_emit_start = max(0, self._next_emit_start + offs)
531
+ print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
532
+ self._realign_emit_pointer_to_bar(sr)
 
 
 
 
 
 
533
  self._needs_bar_realign = False
534
  self._reseed_ref_loop = None
535
 
536
+ # 3) Emit exactly bars_per_chunk × spb from the stream
537
+ start = self._next_emit_start
538
+ step_total = chunk_step_f + self._emit_phase
539
+ step_int = int(np.floor(step_total))
540
+ self._emit_phase = float(step_total - step_int)
541
+ end = start + step_int
542
+ if end > self._stream.shape[0]:
543
+ # shouldn't happen often; generate a bit more and loop
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  continue
545
 
546
+ slice_ = self._stream[start:end]
547
+ self._next_emit_start = end
 
 
 
 
 
 
548
 
549
+ y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
 
550
 
551
+ # 4) Post-processing / loudness
552
  if self.idx == 0 and self.params.ref_loop is not None:
553
  y, _ = match_loudness_to_reference(
554
  self.params.ref_loop, y,
555
  method=self.params.loudness_mode,
556
  headroom_db=self.params.headroom_db
557
  )
558
+ else:
559
+ apply_micro_fades(y, 3)
560
 
561
+ # 5) Resample + exact-length snap + encode
562
+ b64, meta = self._snap_and_encode(
563
+ y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
564
  )
565
+ meta["xfade_seconds"] = xfade
566
 
567
+ # 6) Publish
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  with self._lock:
569
  self.idx += 1
570
  self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
 
571
  if len(self.outbox) > 10:
572
  cutoff = self._last_delivered_index - 5
573
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
574
 
575
+ # 👉 If a reseed was requested, apply it *now*, between chunks
576
+ if self._pending_reseed is not None:
577
  pkg = self._pending_reseed
578
  self._pending_reseed = None
579
+
580
+ new_state = self.mrt.init_state()
581
+ new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
582
+ self.state = new_state
583
+
584
+ # start a fresh stream and schedule one-time alignment
585
+ self._stream = None
586
+ self._next_emit_start = 0
587
+ self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
588
  self._needs_bar_realign = True
 
589
 
590
+ print("🔁 Reseed installed at bar boundary; will realign before next slice")
591
 
592
+ print(f"✅ Completed chunk {self.idx}")
 
 
 
 
 
 
 
593
 
594
  print("🛑 JamWorker stopped")
595
+