thecollabagepatch committed on
Commit
e686e92
·
1 Parent(s): 4587340

another try

Browse files
Files changed (1) hide show
  1. jam_worker.py +256 -117
jam_worker.py CHANGED
@@ -409,168 +409,307 @@ class JamWorker(threading.Thread):
409
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
410
 
411
  def run(self):
412
- """Main worker loop — generate into a continuous stream, then emit bar-aligned slices."""
413
- spb = self._seconds_per_bar() # seconds per bar
414
- chunk_secs = self.params.bars_per_chunk * spb
415
- xfade = float(self.mrt.config.crossfade_length) # seconds
416
- sr = int(self.mrt.sample_rate)
417
- chunk_samps = int(round(chunk_secs * sr))
418
-
419
def _need(first_chunk_extra=False):
    """Return how many more samples the stream needs before the next slice can be emitted.

    Args:
        first_chunk_extra: When true, two extra bars' worth of samples are
            added to the target so the first-chunk onset alignment has
            material to work with.

    Returns:
        The non-negative shortfall, in samples, between the target length
        and what is buffered past the current emit pointer.
    """
    stream = getattr(self, "_stream", None)
    if stream is None:
        available = 0
    else:
        # Only material past the emit pointer counts as available.
        available = stream.shape[0] - getattr(self, "_next_emit_start", 0)

    target = chunk_samps
    if first_chunk_extra:
        target += int(round(2 * spb * sr))
    return max(0, target - available)
427
 
428
- def _mono_env(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
429
- if x.ndim == 2: x = x.mean(axis=1)
430
- x = np.abs(x).astype(np.float32)
431
- w = max(1, int(round(win_ms * 1e-3 * sr)))
432
- if w > 1:
433
- kern = np.ones(w, dtype=np.float32) / float(w)
434
- x = np.convolve(x, kern, mode="same")
435
- d = np.diff(x, prepend=x[:1])
436
- d[d < 0] = 0.0
437
- return d
438
-
439
def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
    """Tempo-aware first-downbeat offset (positive => model late).

    Compares the onset envelope of the last bar of the reference loop with
    the onset envelope of the first two bars of the generated audio, and
    returns the lag (in samples at ``sr``) with the highest normalized
    correlation.

    Args:
        ref_loop_wav: Reference waveform; must expose ``sample_rate``,
            ``samples`` and ``resample(sr)``.
        gen_head_wav: Generated waveform; must expose ``samples``.
        sr: Sample rate the lag is expressed in.
        spb: Seconds per bar; bounds the search window to a quarter bar,
            clamped to 160-450 ms.

    Returns:
        The best lag in samples, or 0 when either excerpt is empty or any
        step fails (estimation is deliberately best-effort).
    """
    try:
        import numpy as np

        max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
        ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
        n_bar = int(round(spb * sr))
        ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
        gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
        if ref_tail.size == 0 or gen_head.size == 0:
            return 0

        # envelopes + z-score
        def _z(a):
            m, s = float(a.mean()), float(a.std() or 1.0)
            return (a - m) / s

        e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
        e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)

        # upsample x4 for finer lag
        def _upsample(a, r=4):
            n = len(a)
            grid = np.arange(n, dtype=np.float32)
            fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
            return np.interp(fine, grid, a).astype(np.float32)

        up = 4
        e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)

        max_lag_u = int(round((max_ms / 1000.0) * sr * up))
        seg = min(len(e_ref_u), len(e_gen_u))
        e_ref_u = e_ref_u[-seg:]
        pad = np.zeros(max_lag_u, dtype=np.float32)
        e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])

        # The reference norm does not depend on the lag; compute it once
        # instead of once per candidate lag.
        ref_norm = np.linalg.norm(e_ref_u)

        best_lag_u, best_score = 0, -1e9
        for lag_u in range(-max_lag_u, max_lag_u + 1):
            start = max_lag_u + lag_u
            b = e_gen_u_pad[start : start + seg]
            denom = (ref_norm * np.linalg.norm(b)) or 1.0
            score = float(np.dot(e_ref_u, b) / denom)
            if score > best_score:
                best_score, best_lag_u = score, lag_u
        return int(round(best_lag_u / up))
    except Exception:
        # Best-effort: any failure means "no offset compensation".
        return 0
482
-
483
- print("🚀 JamWorker started (bar-aligned streaming)…")
484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  while not self._stop_event.is_set():
 
486
  if not self._should_generate_next_chunk():
487
  time.sleep(0.25)
488
  continue
489
 
490
- # 1) Generate until we have enough material in the stream
491
- need = _need(first_chunk_extra=(self.idx == 0))
492
- while need > 0 and not self._stop_event.is_set():
 
493
  with self._lock:
494
  style_vec = self.params.style_vec
495
  self.mrt.guidance_weight = float(self.params.guidance_weight)
496
  self.mrt.temperature = float(self.params.temperature)
497
  self.mrt.topk = int(self.params.topk)
 
498
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
499
- self._append_model_chunk_to_stream(wav) # equal-power xfade into a persistent stream
500
- need = _need(first_chunk_extra=(self.idx == 0))
 
 
 
 
501
 
502
  if self._stop_event.is_set():
503
  break
504
 
505
- # 2) One-time: align the emit pointer to the groove
506
- if (self.idx == 0 and self.params.combined_loop is not None) or self._needs_bar_realign:
507
- ref_loop = self._reseed_ref_loop or self.params.combined_loop
508
- if ref_loop is not None:
509
- head_len = min(self._stream.shape[0] - self._next_emit_start, int(round(2 * spb * sr)))
510
- seg = self._stream[self._next_emit_start : self._next_emit_start + head_len]
511
- gen_head = au.Waveform(seg.astype(np.float32, copy=False), sr).as_stereo()
512
- offs = _estimate_first_offset_samples(ref_loop, gen_head, sr, spb)
513
- if offs != 0:
514
- self._next_emit_start = max(0, self._next_emit_start + offs)
515
- print(f"🎯 Offset compensation: {offs/sr:+.3f}s")
516
- self._realign_emit_pointer_to_bar(sr)
517
  self._needs_bar_realign = False
518
- self._reseed_ref_loop = None
519
 
520
- # 3) Emit exactly bars_per_chunk × spb from the stream
521
- start = self._next_emit_start
522
- end = start + chunk_samps
523
- if end > self._stream.shape[0]:
524
- # shouldn't happen often; generate a bit more and loop
525
- continue
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
- slice_ = self._stream[start:end]
528
- self._next_emit_start = end
529
 
530
- y = au.Waveform(slice_.astype(np.float32, copy=False), sr).as_stereo()
 
 
531
 
532
- # 4) Post-processing / loudness
533
- if self.idx == 0 and self.params.ref_loop is not None:
534
- y, _ = match_loudness_to_reference(
535
- self.params.ref_loop, y,
536
- method=self.params.loudness_mode,
537
- headroom_db=self.params.headroom_db
538
- )
539
- else:
540
- apply_micro_fades(y, 3)
541
 
542
- # 5) Resample + exact-length snap + encode
543
- b64, meta = self._snap_and_encode(
544
- y, seconds=chunk_secs, target_sr=self.params.target_sr, bars=self.params.bars_per_chunk
545
- )
546
- meta["xfade_seconds"] = xfade
547
 
548
- # 6) Publish
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  with self._lock:
550
  self.idx += 1
551
  self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
 
552
  if len(self.outbox) > 10:
553
  cutoff = self._last_delivered_index - 5
554
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
555
 
556
- # 👉 If a reseed was requested, apply it *now*, between chunks
557
- if self._pending_reseed is not None:
558
  pkg = self._pending_reseed
559
  self._pending_reseed = None
560
 
 
561
  new_state = self.mrt.init_state()
562
- new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
563
  self.state = new_state
564
 
565
- # start a fresh stream and schedule one-time alignment
566
  self._stream = None
567
  self._next_emit_start = 0
568
  self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
569
  self._needs_bar_realign = True
570
 
 
 
 
 
 
 
 
 
 
 
 
571
  print("🔁 Reseed installed at bar boundary; will realign before next slice")
572
 
 
 
 
 
 
 
 
573
  print(f"✅ Completed chunk {self.idx}")
574
 
575
- print("🛑 JamWorker stopped")
 
 
 
 
576
 
 
 
409
  self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
410
 
411
  def run(self):
412
+ """
413
+ Main worker loop:
414
+ Generate continuous audio at model/native SR (sr_in).
415
+ Maintain input-domain emit pointer for groove realign.
416
+ Maintain an OUTPUT-domain streaming resampler (sr_out = 44100 by default).
417
+ Emit EXACTLY bars_per_chunk at sr_out using a fractional phase accumulator.
418
+ • No per-chunk resampling; resampler carries state across chunks => seamless.
419
+ """
420
+ import numpy as np
421
+ import time
422
+ from math import floor, ceil
423
+ from utils import wav_bytes_base64, match_loudness_to_reference, apply_micro_fades
424
+
425
+ # ---------- Session timing ----------
426
+ spb = self._seconds_per_bar() # seconds per bar
427
+ chunk_secs = float(self.params.bars_per_chunk) * float(spb) # seconds per emitted chunk
428
+
429
+ # ---------- Sample rates ----------
430
+ sr_in = int(self.mrt.sample_rate) # model/native SR (e.g., 48000)
431
+ sr_out = int(getattr(self.params, "target_sr", 44100) or 44100) # desired client SR (44.1k by default)
432
+ self.params.target_sr = sr_out # reflect back in metadata
433
+
434
+ # ---------- Crossfade (model-side stitching), seconds ----------
435
+ xfade_seconds = float(self.mrt.config.crossfade_length)
436
+
437
+ # ---------- INPUT-domain emit step (used for groove realign + generation need) ----------
438
+ chunk_step_in_f = chunk_secs * sr_in # float samples per chunk (input domain)
439
+ self._emit_phase = float(getattr(self, "_emit_phase", 0.0)) # carry across loops
440
+
441
+ # ---------- OUTPUT-domain emit step (controls exact client length) ----------
442
+ chunk_step_out_f = chunk_secs * sr_out
443
+ self._emit_phase_out = float(getattr(self, "_emit_phase_out", 0.0))
444
+ self._next_emit_start_out = int(getattr(self, "_next_emit_start_out", 0))
445
+
446
+ # ---------- Continuous resampler state (into sr_out) ----------
447
+ self._resampler = None
448
+ self._stream_out = np.zeros((0, int(self.params.channels or 2)), dtype=np.float32)
449
+ if sr_out != sr_in:
450
+ # Lazy import to avoid hard dep if not needed
451
+ from utils import StreamingResampler
452
+ ch = int(self.params.channels or 2)
453
+ self._resampler = StreamingResampler(in_sr=sr_in, out_sr=sr_out, channels=ch, quality="VHQ")
454
+
455
+ # ---------- INPUT stream / pointers ----------
456
+ # self._stream: np.ndarray (S_in, C) grows as we generate
457
+ # self._next_emit_start: input-domain pointer we realign to bar boundary once at start / reseed
458
+ self._stream = getattr(self, "_stream", None)
459
+ self._next_emit_start = int(getattr(self, "_next_emit_start", 0))
460
+ self._needs_bar_realign = bool(getattr(self, "_needs_bar_realign", True))
461
+
462
+ # How much of INPUT we have already fed into the resampler (in samples @ sr_in)
463
+ input_consumed = int(getattr(self, "_input_consumed", 0))
464
+
465
+ # Delivery bookkeeping
466
+ self.idx = int(getattr(self, "idx", 0))
467
+ self._last_delivered_index = int(getattr(self, "_last_delivered_index", 0))
468
+ self.outbox = getattr(self, "outbox", [])
469
+
470
+ print("🚀 JamWorker started (bar-aligned streaming, stateful resampler)…")
471
+
472
+ # ---------- Helpers inside run() ----------
473
def _need_input(first_chunk_extra: bool = False) -> int:
    """Return the INPUT-domain sample shortfall before the next slice can be emitted.

    The target is the integer step the input pointer would advance by
    (chunk length plus the carried fractional phase, floored). This helper
    only reads ``_emit_phase``; it never mutates it.

    Args:
        first_chunk_extra: When true, two extra bars of input are added to
            the target as a safety margin for downbeat/onset alignment.

    Returns:
        The non-negative number of additional input samples required.
    """
    buffered = self._stream.shape[0] if self._stream is not None else 0
    emit_from = int(getattr(self, "_next_emit_start", 0))
    available = max(0, buffered - emit_from)

    phase = float(getattr(self, "_emit_phase", 0.0))
    required = int(floor(chunk_step_in_f + phase))
    if first_chunk_extra:
        required += int(ceil(2.0 * spb * sr_in))

    return max(0, required - available)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
+ def _feed_resampler_as_needed():
494
+ """
495
+ Ensure OUTPUT buffer (_stream_out) has resampled audio for any new INPUT
496
+ samples appended to self._stream since we last consumed it.
497
+ """
498
+ nonlocal input_consumed, sr_in, sr_out
499
+ total_in = 0 if self._stream is None else self._stream.shape[0]
500
+ if total_in <= input_consumed:
501
+ return # nothing new to feed
502
+
503
+ # Slice the new INPUT region and push through streaming resampler (or pass-through)
504
+ new_in = self._stream[input_consumed:total_in]
505
+ if new_in.size == 0:
506
+ return
507
+
508
+ if self._resampler is not None:
509
+ y_out = self._resampler.process(new_in, final=False)
510
+ else:
511
+ # No resampling needed; alias output to input
512
+ y_out = new_in
513
+
514
+ if y_out.size:
515
+ self._stream_out = y_out if self._stream_out.size == 0 else np.vstack([self._stream_out, y_out])
516
+
517
+ input_consumed = total_in # we've fed all available input into the (re)sampler
518
+
519
def _output_have():
    """Return the number of OUTPUT-domain samples buffered past the output emit pointer."""
    if self._stream_out is None:
        buffered = 0
    else:
        buffered = self._stream_out.shape[0]
    return max(0, buffered - self._next_emit_start_out)
523
+
524
def _compute_step_in() -> int:
    """Return the integer INPUT-domain step for the next chunk without changing any state."""
    carried = float(getattr(self, "_emit_phase", 0.0))
    return int(floor(chunk_step_in_f + carried))
527
+
528
def _compute_step_out() -> int:
    """Return the integer OUTPUT-domain step for the next emission without changing any state."""
    carried = float(getattr(self, "_emit_phase_out", 0.0))
    return int(floor(chunk_step_out_f + carried))
531
+
532
def _advance_input_pointer():
    """Move the INPUT emit pointer forward one chunk, carrying the fractional remainder.

    The float chunk length plus the previously carried phase is split into
    a whole number of samples (added to ``_next_emit_start``) and a
    fractional part (stored back in ``_emit_phase``).
    """
    exact = chunk_step_in_f + self._emit_phase
    whole = int(floor(exact))
    self._next_emit_start += whole
    self._emit_phase = float(exact - whole)
538
+
539
def _advance_output_pointer():
    """Move the OUTPUT emit pointer forward one chunk, carrying the fractional remainder.

    The float chunk length plus the previously carried phase is split into
    a whole number of samples (added to ``_next_emit_start_out``) and a
    fractional part (stored back in ``_emit_phase_out``).
    """
    exact = chunk_step_out_f + self._emit_phase_out
    whole = int(floor(exact))
    self._next_emit_start_out += whole
    self._emit_phase_out = float(exact - whole)
545
+
546
+ def _trim_buffers_if_needed():
547
+ """
548
+ Keep memory bounded by dropping already-emitted OUTPUT and corresponding INPUT,
549
+ while keeping indices consistent.
550
+ """
551
+ # Drop OUTPUT head
552
+ if self._next_emit_start_out > 3 * int(chunk_step_out_f or sr_out):
553
+ cut = int(self._next_emit_start_out)
554
+ self._stream_out = self._stream_out[cut:]
555
+ self._next_emit_start_out -= cut
556
+
557
+ # Drop INPUT head **only** if we've consumed it into resampler AND it's before emit start
558
+ # (emit start is for alignment math; after first chunk we keep advancing anyway)
559
+ head_can_drop = min(input_consumed, self._next_emit_start)
560
+ if head_can_drop > sr_in * 8: # keep a few bars as safety
561
+ drop = head_can_drop - int(sr_in * 4)
562
+ if drop > 0:
563
+ self._stream = self._stream[drop:]
564
+ self._next_emit_start -= drop
565
+ input_consumed -= drop
566
+
567
+ # ---------- Main loop ----------
568
  while not self._stop_event.is_set():
569
+ # Throttle if we're too far ahead of the consumer
570
  if not self._should_generate_next_chunk():
571
  time.sleep(0.25)
572
  continue
573
 
574
+ # 1) Ensure we have enough INPUT material for the next slice (and first-chunk extra)
575
+ need_in = _need_input(first_chunk_extra=(self.idx == 0))
576
+ while need_in > 0 and not self._stop_event.is_set():
577
+ # Model generation step; xfade into persistent INPUT stream
578
  with self._lock:
579
  style_vec = self.params.style_vec
580
  self.mrt.guidance_weight = float(self.params.guidance_weight)
581
  self.mrt.temperature = float(self.params.temperature)
582
  self.mrt.topk = int(self.params.topk)
583
+
584
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
585
+ self._append_model_chunk_to_stream(wav) # equal-power crossfade into self._stream
586
+
587
+ # Feed any newly appended INPUT into the OUTPUT resampler
588
+ _feed_resampler_as_needed()
589
+
590
+ need_in = _need_input(first_chunk_extra=(self.idx == 0))
591
 
592
  if self._stop_event.is_set():
593
  break
594
 
595
+ # 2) One-time: tempo/bar realign in INPUT domain before emitting the *first* chunk
596
+ if self._needs_bar_realign:
597
+ self._realign_emit_pointer_to_bar(sr_in)
598
+ self._emit_phase = 0.0 # reset input fractional phase after snapping to grid
599
+
600
+ # Set INPUT→RESAMPLER start so the very first OUTPUT sample corresponds to _next_emit_start
601
+ input_consumed = max(input_consumed, self._next_emit_start)
 
 
 
 
 
602
  self._needs_bar_realign = False
 
603
 
604
+ # Feed any post-snap INPUT into OUTPUT resampler so we have aligned OUTPUT available
605
+ _feed_resampler_as_needed()
606
+
607
+ # 3) Ensure OUTPUT buffer has enough samples for the next emission step
608
+ step_out_int = _compute_step_out()
609
+ while _output_have() < step_out_int and not self._stop_event.is_set():
610
+ # If OUTPUT is short, try feeding more INPUT into resampler; if INPUT has no new data, generate more
611
+ _feed_resampler_as_needed()
612
+ if _output_have() < step_out_int:
613
+ # generate another model chunk
614
+ with self._lock:
615
+ style_vec = self.params.style_vec
616
+ self.mrt.guidance_weight = float(self.params.guidance_weight)
617
+ self.mrt.temperature = float(self.params.temperature)
618
+ self.mrt.topk = int(self.params.topk)
619
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
620
+ self._append_model_chunk_to_stream(wav)
621
+ _feed_resampler_as_needed()
622
 
623
+ if self._stop_event.is_set():
624
+ break
625
 
626
+ # 4) Slice OUTPUT-domain chunk exactly step_out_int long and (optionally) loudness-align the first one
627
+ start_out = int(self._next_emit_start_out)
628
+ end_out = start_out + int(step_out_int)
629
 
630
+ total_out = 0 if self._stream_out is None else self._stream_out.shape[0]
631
+ if end_out > total_out:
632
+ # Should be rare due to loop above, but guard anyway
633
+ time.sleep(0.01)
634
+ continue
 
 
 
 
635
 
636
+ y_send = self._stream_out[start_out:end_out]
 
 
 
 
637
 
638
+ if self.idx == 0 and getattr(self.params, "ref_loop", None) is not None:
639
+ # First chunk: match loudness to reference if requested
640
+ y_send, _ = match_loudness_to_reference(
641
+ self.params.ref_loop, y_send,
642
+ method=getattr(self.params, "loudness_mode", "integrated"),
643
+ headroom_db=getattr(self.params, "headroom_db", 1.0),
644
+ )
645
+ else:
646
+ # With a continuous stateful resampler, no per-chunk fades are needed.
647
+ # If you *really* want safety fades, do 1 ms only on first/last when stopping.
648
+ pass
649
+
650
+ # 5) Encode WAV (already exact length at sr_out)
651
+ b64, total_samples, channels = wav_bytes_base64(y_send, sr_out)
652
+ meta = {
653
+ "bpm": float(self.params.bpm),
654
+ "bars": int(self.params.bars_per_chunk),
655
+ "seconds": float(chunk_secs),
656
+ "sample_rate": int(sr_out),
657
+ "samples": int(total_samples),
658
+ "channels": int(channels),
659
+ "xfade_seconds": float(xfade_seconds),
660
+ }
661
+
662
+ # 6) Publish + advance both emit pointers
663
  with self._lock:
664
  self.idx += 1
665
  self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
666
+ # prune outbox to keep memory in check
667
  if len(self.outbox) > 10:
668
  cutoff = self._last_delivered_index - 5
669
  self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
670
 
671
+ # Handle reseed requests BETWEEN chunks
672
+ if getattr(self, "_pending_reseed", None) is not None:
673
  pkg = self._pending_reseed
674
  self._pending_reseed = None
675
 
676
+ # Reset model state with fresh bar-aligned context tokens
677
  new_state = self.mrt.init_state()
678
+ new_state.context_tokens = pkg["ctx"]
679
  self.state = new_state
680
 
681
+ # Reset INPUT stream and schedule one-time bar realign
682
  self._stream = None
683
  self._next_emit_start = 0
684
  self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
685
  self._needs_bar_realign = True
686
 
687
+ # Reset OUTPUT-domain streaming state
688
+ self._stream_out = np.zeros((0, int(self.params.channels or 2)), dtype=np.float32)
689
+ self._next_emit_start_out = 0
690
+ self._emit_phase_out = 0.0
691
+ input_consumed = 0
692
+ if self._resampler is not None:
693
+ # Rebuild the resampler to clear its filter tail
694
+ from utils import StreamingResampler
695
+ ch = int(self.params.channels or 2)
696
+ self._resampler = StreamingResampler(in_sr=sr_in, out_sr=sr_out, channels=ch, quality="VHQ")
697
+
698
  print("🔁 Reseed installed at bar boundary; will realign before next slice")
699
 
700
+ # Advance both emit pointers for next round
701
+ _advance_input_pointer()
702
+ _advance_output_pointer()
703
+
704
+ # Keep memory tidy
705
+ _trim_buffers_if_needed()
706
+
707
  print(f"✅ Completed chunk {self.idx}")
708
 
709
+ # Stop: flush tail from resampler (optional)
710
+ if self._resampler is not None:
711
+ tail = self._resampler.flush()
712
+ if tail.size:
713
+ self._stream_out = tail if self._stream_out.size == 0 else np.vstack([self._stream_out, tail])
714
 
715
+ print("🛑 JamWorker stopped")