thecollabagepatch commited on
Commit
db53efe
Β·
1 Parent(s): e686e92

reverting again

Browse files
Files changed (1) hide show
  1. jam_worker.py +117 -256
jam_worker.py CHANGED
@@ -409,307 +409,168 @@ class JamWorker(threading.Thread):
409
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
410
 
411
  def run(self):
412
- """
413
- Main worker loop:
414
- β€’ Generate continuous audio at model/native SR (sr_in).
415
- β€’ Maintain input-domain emit pointer for groove realign.
416
- β€’ Maintain an OUTPUT-domain streaming resampler (sr_out = 44100 by default).
417
- β€’ Emit EXACTLY bars_per_chunk at sr_out using a fractional phase accumulator.
418
- β€’ No per-chunk resampling; resampler carries state across chunks => seamless.
419
- """
420
- import numpy as np
421
- import time
422
- from math import floor, ceil
423
- from utils import wav_bytes_base64, match_loudness_to_reference, apply_micro_fades
424
-
425
- # ---------- Session timing ----------
426
- spb = self._seconds_per_bar() # seconds per bar
427
- chunk_secs = float(self.params.bars_per_chunk) * float(spb) # seconds per emitted chunk
428
-
429
- # ---------- Sample rates ----------
430
- sr_in = int(self.mrt.sample_rate) # model/native SR (e.g., 48000)
431
- sr_out = int(getattr(self.params, "target_sr", 44100) or 44100) # desired client SR (44.1k by default)
432
- self.params.target_sr = sr_out # reflect back in metadata
433
-
434
- # ---------- Crossfade (model-side stitching), seconds ----------
435
- xfade_seconds = float(self.mrt.config.crossfade_length)
436
-
437
- # ---------- INPUT-domain emit step (used for groove realign + generation need) ----------
438
- chunk_step_in_f = chunk_secs * sr_in # float samples per chunk (input domain)
439
- self._emit_phase = float(getattr(self, "_emit_phase", 0.0)) # carry across loops
440
-
441
- # ---------- OUTPUT-domain emit step (controls exact client length) ----------
442
- chunk_step_out_f = chunk_secs * sr_out
443
- self._emit_phase_out = float(getattr(self, "_emit_phase_out", 0.0))
444
- self._next_emit_start_out = int(getattr(self, "_next_emit_start_out", 0))
445
-
446
- # ---------- Continuous resampler state (into sr_out) ----------
447
- self._resampler = None
448
- self._stream_out = np.zeros((0, int(self.params.channels or 2)), dtype=np.float32)
449
- if sr_out != sr_in:
450
- # Lazy import to avoid hard dep if not needed
451
- from utils import StreamingResampler
452
- ch = int(self.params.channels or 2)
453
- self._resampler = StreamingResampler(in_sr=sr_in, out_sr=sr_out, channels=ch, quality="VHQ")
454
-
455
- # ---------- INPUT stream / pointers ----------
456
- # self._stream: np.ndarray (S_in, C) grows as we generate
457
- # self._next_emit_start: input-domain pointer we realign to bar boundary once at start / reseed
458
- self._stream = getattr(self, "_stream", None)
459
- self._next_emit_start = int(getattr(self, "_next_emit_start", 0))
460
- self._needs_bar_realign = bool(getattr(self, "_needs_bar_realign", True))
461
-
462
- # How much of INPUT we have already fed into the resampler (in samples @ sr_in)
463
- input_consumed = int(getattr(self, "_input_consumed", 0))
464
-
465
- # Delivery bookkeeping
466
- self.idx = int(getattr(self, "idx", 0))
467
- self._last_delivered_index = int(getattr(self, "_last_delivered_index", 0))
468
- self.outbox = getattr(self, "outbox", [])
469
-
470
- print("πŸš€ JamWorker started (bar-aligned streaming, stateful resampler)…")
471
-
472
- # ---------- Helpers inside run() ----------
473
- def _need_input(first_chunk_extra: bool = False) -> int:
474
- """
475
- How many INPUT-domain samples we still need in self._stream to be comfortable
476
- before emitting the next slice. Mirrors your fractional step math without
477
- mutating _emit_phase here.
478
- """
479
- total = 0 if self._stream is None else self._stream.shape[0]
480
- start = int(getattr(self, "_next_emit_start", 0))
481
- have = max(0, total - start)
482
-
483
- # Integer step we will advance by (input domain), non-mutating:
484
- step_int = int(floor(chunk_step_in_f + float(getattr(self, "_emit_phase", 0.0))))
485
-
486
- want = step_int
487
- if first_chunk_extra:
488
- # reserve 2 extra bars for downbeat/onset alignment safety
489
- want += int(ceil(2.0 * spb * sr_in))
490
 
 
 
 
 
 
 
 
491
  return max(0, want - have)
492
 
493
- def _feed_resampler_as_needed():
494
- """
495
- Ensure OUTPUT buffer (_stream_out) has resampled audio for any new INPUT
496
- samples appended to self._stream since we last consumed it.
497
- """
498
- nonlocal input_consumed, sr_in, sr_out
499
- total_in = 0 if self._stream is None else self._stream.shape[0]
500
- if total_in <= input_consumed:
501
- return # nothing new to feed
502
-
503
- # Slice the new INPUT region and push through streaming resampler (or pass-through)
504
- new_in = self._stream[input_consumed:total_in]
505
- if new_in.size == 0:
506
- return
507
-
508
- if self._resampler is not None:
509
- y_out = self._resampler.process(new_in, final=False)
510
- else:
511
- # No resampling needed; alias output to input
512
- y_out = new_in
513
-
514
- if y_out.size:
515
- self._stream_out = y_out if self._stream_out.size == 0 else np.vstack([self._stream_out, y_out])
516
-
517
- input_consumed = total_in # we've fed all available input into the (re)sampler
518
-
519
- def _output_have():
520
- """How many OUTPUT-domain samples are available to emit from current pointer."""
521
- total_out = 0 if self._stream_out is None else self._stream_out.shape[0]
522
- return max(0, total_out - self._next_emit_start_out)
523
-
524
- def _compute_step_in() -> int:
525
- """Integer input-domain step for internal pointer (non-mutating)."""
526
- return int(floor(chunk_step_in_f + float(getattr(self, "_emit_phase", 0.0))))
527
-
528
- def _compute_step_out() -> int:
529
- """Integer output-domain step for emission (non-mutating)."""
530
- return int(floor(chunk_step_out_f + float(getattr(self, "_emit_phase_out", 0.0))))
531
-
532
- def _advance_input_pointer():
533
- """Advance input emit pointer by the integer step and carry fractional phase."""
534
- step_total = chunk_step_in_f + self._emit_phase
535
- step_int = int(floor(step_total))
536
- self._emit_phase = float(step_total - step_int)
537
- self._next_emit_start += step_int
538
-
539
- def _advance_output_pointer():
540
- """Advance output emit pointer by the integer step and carry fractional phase."""
541
- step_total = chunk_step_out_f + self._emit_phase_out
542
- step_int = int(floor(step_total))
543
- self._emit_phase_out = float(step_total - step_int)
544
- self._next_emit_start_out += step_int
545
-
546
- def _trim_buffers_if_needed():
547
- """
548
- Keep memory bounded by dropping already-emitted OUTPUT and corresponding INPUT,
549
- while keeping indices consistent.
550
- """
551
- # Drop OUTPUT head
552
- if self._next_emit_start_out > 3 * int(chunk_step_out_f or sr_out):
553
- cut = int(self._next_emit_start_out)
554
- self._stream_out = self._stream_out[cut:]
555
- self._next_emit_start_out -= cut
556
-
557
- # Drop INPUT head **only** if we've consumed it into resampler AND it's before emit start
558
- # (emit start is for alignment math; after first chunk we keep advancing anyway)
559
- head_can_drop = min(input_consumed, self._next_emit_start)
560
- if head_can_drop > sr_in * 8: # keep a few bars as safety
561
- drop = head_can_drop - int(sr_in * 4)
562
- if drop > 0:
563
- self._stream = self._stream[drop:]
564
- self._next_emit_start -= drop
565
- input_consumed -= drop
566
-
567
- # ---------- Main loop ----------
568
  while not self._stop_event.is_set():
569
- # Throttle if we're too far ahead of the consumer
570
  if not self._should_generate_next_chunk():
571
  time.sleep(0.25)
572
  continue
573
 
574
- # 1) Ensure we have enough INPUT material for the next slice (and first-chunk extra)
575
- need_in = _need_input(first_chunk_extra=(self.idx == 0))
576
- while need_in > 0 and not self._stop_event.is_set():
577
- # Model generation step; xfade into persistent INPUT stream
578
  with self._lock:
579
  style_vec = self.params.style_vec
580
  self.mrt.guidance_weight = float(self.params.guidance_weight)
581
  self.mrt.temperature = float(self.params.temperature)
582
  self.mrt.topk = int(self.params.topk)
583
-
584
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
585
- self._append_model_chunk_to_stream(wav) # equal-power crossfade into self._stream
586
-
587
- # Feed any newly appended INPUT into the OUTPUT resampler
588
- _feed_resampler_as_needed()
589
-
590
- need_in = _need_input(first_chunk_extra=(self.idx == 0))
591
 
592
  if self._stop_event.is_set():
593
  break
594
 
595
- # 2) One-time: tempo/bar realign in INPUT domain before emitting the *first* chunk
596
- if self._needs_bar_realign:
597
- self._realign_emit_pointer_to_bar(sr_in)
598
- self._emit_phase = 0.0 # reset input fractional phase after snapping to grid
599
-
600
- # Set INPUT→RESAMPLER start so the very first OUTPUT sample corresponds to _next_emit_start
601
- input_consumed = max(input_consumed, self._next_emit_start)
 
 
 
 
 
602
  self._needs_bar_realign = False
 
603
 
604
- # Feed any post-snap INPUT into OUTPUT resampler so we have aligned OUTPUT available
605
- _feed_resampler_as_needed()
606
-
607
- # 3) Ensure OUTPUT buffer has enough samples for the next emission step
608
- step_out_int = _compute_step_out()
609
- while _output_have() < step_out_int and not self._stop_event.is_set():
610
- # If OUTPUT is short, try feeding more INPUT into resampler; if INPUT has no new data, generate more
611
- _feed_resampler_as_needed()
612
- if _output_have() < step_out_int:
613
- # generate another model chunk
614
- with self._lock:
615
- style_vec = self.params.style_vec
616
- self.mrt.guidance_weight = float(self.params.guidance_weight)
617
- self.mrt.temperature = float(self.params.temperature)
618
- self.mrt.topk = int(self.params.topk)
619
- wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
620
- self._append_model_chunk_to_stream(wav)
621
- _feed_resampler_as_needed()
622
-
623
- if self._stop_event.is_set():
624
- break
625
-
626
- # 4) Slice OUTPUT-domain chunk exactly step_out_int long and (optionally) loudness-align the first one
627
- start_out = int(self._next_emit_start_out)
628
- end_out = start_out + int(step_out_int)
629
-
630
- total_out = 0 if self._stream_out is None else self._stream_out.shape[0]
631
- if end_out > total_out:
632
- # Should be rare due to loop above, but guard anyway
633
- time.sleep(0.01)
634
  continue
635
 
636
- y_send = self._stream_out[start_out:end_out]
 
637
 
638
- if self.idx == 0 and getattr(self.params, "ref_loop", None) is not None:
639
- # First chunk: match loudness to reference if requested
640
- y_send, _ = match_loudness_to_reference(
641
- self.params.ref_loop, y_send,
642
- method=getattr(self.params, "loudness_mode", "integrated"),
643
- headroom_db=getattr(self.params, "headroom_db", 1.0),
 
 
644
  )
645
  else:
646
- # With a continuous stateful resampler, no per-chunk fades are needed.
647
- # If you *really* want safety fades, do 1 ms only on first/last when stopping.
648
- pass
649
-
650
- # 5) Encode WAV (already exact length at sr_out)
651
- b64, total_samples, channels = wav_bytes_base64(y_send, sr_out)
652
- meta = {
653
- "bpm": float(self.params.bpm),
654
- "bars": int(self.params.bars_per_chunk),
655
- "seconds": float(chunk_secs),
656
- "sample_rate": int(sr_out),
657
- "samples": int(total_samples),
658
- "channels": int(channels),
659
- "xfade_seconds": float(xfade_seconds),
660
- }
661
-
662
- # 6) Publish + advance both emit pointers
663
  with self._lock:
664
  self.idx += 1
665
  self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
666
- # prune outbox to keep memory in check
667
  if len(self.outbox) > 10:
668
  cutoff = self._last_delivered_index - 5
669
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
670
 
671
- # Handle reseed requests BETWEEN chunks
672
- if getattr(self, "_pending_reseed", None) is not None:
673
  pkg = self._pending_reseed
674
  self._pending_reseed = None
675
 
676
- # Reset model state with fresh bar-aligned context tokens
677
  new_state = self.mrt.init_state()
678
- new_state.context_tokens = pkg["ctx"]
679
  self.state = new_state
680
 
681
- # Reset INPUT stream and schedule one-time bar realign
682
  self._stream = None
683
  self._next_emit_start = 0
684
  self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
685
  self._needs_bar_realign = True
686
 
687
- # Reset OUTPUT-domain streaming state
688
- self._stream_out = np.zeros((0, int(self.params.channels or 2)), dtype=np.float32)
689
- self._next_emit_start_out = 0
690
- self._emit_phase_out = 0.0
691
- input_consumed = 0
692
- if self._resampler is not None:
693
- # Rebuild the resampler to clear its filter tail
694
- from utils import StreamingResampler
695
- ch = int(self.params.channels or 2)
696
- self._resampler = StreamingResampler(in_sr=sr_in, out_sr=sr_out, channels=ch, quality="VHQ")
697
-
698
  print("πŸ” Reseed installed at bar boundary; will realign before next slice")
699
 
700
- # Advance both emit pointers for next round
701
- _advance_input_pointer()
702
- _advance_output_pointer()
703
-
704
- # Keep memory tidy
705
- _trim_buffers_if_needed()
706
-
707
  print(f"βœ… Completed chunk {self.idx}")
708
 
709
- # Stop: flush tail from resampler (optional)
710
- if self._resampler is not None:
711
- tail = self._resampler.flush()
712
- if tail.size:
713
- self._stream_out = tail if self._stream_out.size == 0 else np.vstack([self._stream_out, tail])
714
-
715
  print("πŸ›‘ JamWorker stopped")
 
 
409
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
410
 
411
  def run(self):
412
+ """Main worker loop β€” generate into a continuous stream, then emit bar-aligned slices."""
413
+ spb = self._seconds_per_bar() # seconds per bar
414
+ chunk_secs = self.params.bars_per_chunk * spb
415
+ xfade = float(self.mrt.config.crossfade_length) # seconds
416
+ sr = int(self.mrt.sample_rate)
417
+ chunk_samps = int(round(chunk_secs * sr))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
+ def _need(first_chunk_extra=False):
420
+ """How many more samples we still need in the stream to emit next slice."""
421
+ have = 0 if getattr(self, "_stream", None) is None else self._stream.shape[0] - getattr(self, "_next_emit_start", 0)
422
+ want = chunk_samps
423
+ if first_chunk_extra:
424
+ # reserve two bars extra so first-chunk onset alignment has material
425
+ want += int(round(2 * spb * sr))
426
  return max(0, want - have)
427
 
428
+ def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
429
+ if x.ndim == 2: x = x.mean(axis=1)
430
+ x = np.abs(x).astype(np.float32)
431
+ w = max(1, int(round(win_ms * 1e-3 * sr)))
432
+ if w > 1:
433
+ kern = np.ones(w, dtype=np.float32) / float(w)
434
+ x = np.convolve(x, kern, mode="same")
435
+ d = np.diff(x, prepend=x[:1])
436
+ d[d < 0] = 0.0
437
+ return d
438
+
439
+ def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
440
+ """Tempo-aware first-downbeat offset (positive => model late)."""
441
+ try:
442
+ max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
443
+ ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
444
+ n_bar = int(round(spb * sr))
445
+ ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
446
+ gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
447
+ if ref_tail.size == 0 or gen_head.size == 0:
448
+ return 0
449
+
450
+ # envelopes + z-score
451
+ import numpy as np
452
+ def _z(a):
453
+ m, s = float(a.mean()), float(a.std() or 1.0); return (a - m) / s
454
+ e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
455
+ e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)
456
+
457
+ # upsample x4 for finer lag
458
+ def _upsample(a, r=4):
459
+ n = len(a); grid = np.arange(n, dtype=np.float32)
460
+ fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
461
+ return np.interp(fine, grid, a).astype(np.float32)
462
+ up = 4
463
+ e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)
464
+
465
+ max_lag_u = int(round((max_ms / 1000.0) * sr * up))
466
+ seg = min(len(e_ref_u), len(e_gen_u))
467
+ e_ref_u = e_ref_u[-seg:]
468
+ pad = np.zeros(max_lag_u, dtype=np.float32)
469
+ e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])
470
+
471
+ best_lag_u, best_score = 0, -1e9
472
+ for lag_u in range(-max_lag_u, max_lag_u + 1):
473
+ start = max_lag_u + lag_u
474
+ b = e_gen_u_pad[start : start + seg]
475
+ denom = (np.linalg.norm(e_ref_u) * np.linalg.norm(b)) or 1.0
476
+ score = float(np.dot(e_ref_u, b) / denom)
477
+ if score > best_score:
478
+ best_score, best_lag_u = score, lag_u
479
+ return int(round(best_lag_u / up))
480
+ except Exception:
481
+ return 0
482
+
483
+ print("πŸš€ JamWorker started (bar-aligned streaming)…")
484
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  while not self._stop_event.is_set():
 
486
  if not self._should_generate_next_chunk():
487
  time.sleep(0.25)
488
  continue
489
 
490
+ # 1) Generate until we have enough material in the stream
491
+ need = _need(first_chunk_extra=(self.idx == 0))
492
+ while need > 0 and not self._stop_event.is_set():
 
493
  with self._lock:
494
  style_vec = self.params.style_vec
495
  self.mrt.guidance_weight = float(self.params.guidance_weight)
496
  self.mrt.temperature = float(self.params.temperature)
497
  self.mrt.topk = int(self.params.topk)
 
498
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
499
+ self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
500
+ need = _need(first_chunk_extra=(self.idx == 0))
 
 
 
 
501
 
502
  if self._stop_event.is_set():
503
  break
504
 
505
+ # 2) One-time: align the emit pointer to the groove
506
+ if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
507
+ ref_loop = self._reseed_ref_loop or self.params.combined_loop
508
+ if ref_loop is not None:
509
+ head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
510
+ seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
511
+ gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
512
+ offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
513
+ if offs != 0:
514
+ self._next_emit_start = max(0, self._next_emit_start + offs)
515
+ print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
516
+ self._realign_emit_pointer_to_bar(sr)
517
  self._needs_bar_realign = False
518
+ self._reseed_ref_loop = None
519
 
520
+ # 3) Emit exactly bars_per_chunk Γ— spb from the stream
521
+ start = self._next_emit_start
522
+ end = start + chunk_samps
523
+ if end > self._stream.shape[0]:
524
+ # shouldn't happen often; generate a bit more and loop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  continue
526
 
527
+ slice_ = self._stream[start:end]
528
+ self._next_emit_start = end
529
 
530
+ y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
531
+
532
+ # 4) Post-processing / loudness
533
+ if self.idx == 0 and self.params.ref_loop is not None:
534
+ y, _ = match_loudness_to_reference(
535
+ self.params.ref_loop, y,
536
+ method=self.params.loudness_mode,
537
+ headroom_db=self.params.headroom_db
538
  )
539
  else:
540
+ apply_micro_fades(y, 3)
541
+
542
+ # 5) Resample + exact-length snap + encode
543
+ b64, meta = self._snap_and_encode(
544
+ y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
545
+ )
546
+ meta["xfade_seconds"] = xfade
547
+
548
+ # 6) Publish
 
 
 
 
 
 
 
 
549
  with self._lock:
550
  self.idx += 1
551
  self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
 
552
  if len(self.outbox) > 10:
553
  cutoff = self._last_delivered_index - 5
554
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
555
 
556
+ # πŸ‘‰ If a reseed was requested, apply it *now*, between chunks
557
+ if self._pending_reseed is not None:
558
  pkg = self._pending_reseed
559
  self._pending_reseed = None
560
 
 
561
  new_state = self.mrt.init_state()
562
+ new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
563
  self.state = new_state
564
 
565
+ # start a fresh stream and schedule one-time alignment
566
  self._stream = None
567
  self._next_emit_start = 0
568
  self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
569
  self._needs_bar_realign = True
570
 
 
 
 
 
 
 
 
 
 
 
 
571
  print("πŸ” Reseed installed at bar boundary; will realign before next slice")
572
 
 
 
 
 
 
 
 
573
  print(f"βœ… Completed chunk {self.idx}")
574
 
 
 
 
 
 
 
575
  print("πŸ›‘ JamWorker stopped")
576
+