# jam_worker.py - SIMPLE FIX VERSION
import base64
import io
import math
import threading
import time
import uuid
from dataclasses import dataclass, field
from threading import RLock
from typing import Any

import numpy as np
import soundfile as sf
from magenta_rt import audio as au

from utils import (
    match_loudness_to_reference,
    stitch_generated,
    hard_trim_seconds,
    apply_micro_fades,
    make_bar_aligned_context,
    take_bar_aligned_tail,
    resample_and_snap,
    wav_bytes_base64,
)


@dataclass
class JamParams:
    """Musical, sampling and sampling-knob settings for one jam session."""

    bpm: float
    beats_per_bar: int
    bars_per_chunk: int
    target_sr: int                       # sample rate of the encoded chunks handed to clients
    loudness_mode: str = "auto"          # forwarded as `method=` to match_loudness_to_reference
    headroom_db: float = 1.0
    style_vec: np.ndarray | None = None
    ref_loop: Any = None                 # loudness reference applied to the first chunk only
    combined_loop: Any = None            # seeds the model context and first-chunk grid alignment
    guidance_weight: float = 1.1
    temperature: float = 1.1
    topk: int = 40


@dataclass
class JamChunk:
    """One generated, encoded chunk ready for delivery."""

    index: int           # sequential, starting at 1
    audio_base64: str    # output of wav_bytes_base64
    metadata: dict


# ---------------------------------------------------------------------------
# Signal helpers (used for first-chunk grid alignment)
# ---------------------------------------------------------------------------

def _as_2d(samples: np.ndarray) -> np.ndarray:
    """Return ``samples`` as ``[frames, channels]``, adding a channel axis to 1-D input."""
    return samples if samples.ndim == 2 else samples[:, None]


def _mono_envelope(x: np.ndarray, sr: int, win_ms: float = 10.0) -> np.ndarray:
    """Onset-style novelty curve: rectified moving-average envelope, then half-wave diff."""
    if x.ndim == 2:
        x = x.mean(axis=1)
    x = np.abs(x).astype(np.float32)
    w = max(1, int(round(win_ms * 1e-3 * sr)))
    if w > 1:
        kern = np.ones(w, dtype=np.float32) / float(w)
        x = np.convolve(x, kern, mode="same")
    # Positive first difference only: rises in energy, i.e. onset-ish events.
    d = np.diff(x, prepend=x[:1])
    d[d < 0] = 0.0
    return d


def _zscore(a: np.ndarray) -> np.ndarray:
    """Standardize ``a`` to zero mean / unit std (a zero std is treated as 1)."""
    mean, std = float(a.mean()), float(a.std() or 1.0)
    return ((a - mean) / std).astype(np.float32)


def _upsample_linear(a: np.ndarray, factor: int) -> np.ndarray:
    """Linearly interpolate ``a`` onto a grid ``factor`` times denser."""
    n = len(a)
    grid = np.arange(n, dtype=np.float32)
    fine = np.linspace(0, n - 1, num=n * factor, dtype=np.float32)
    return np.interp(fine, grid, a).astype(np.float32)


def _normalized_sliding_xcorr(ref: np.ndarray, sig: np.ndarray, n_lags: int) -> np.ndarray:
    """Cosine similarity of ``ref`` against each window ``sig[k : k + len(ref)]``.

    Returns ``scores[k] = dot(ref, w_k) / (|ref| * |w_k|)`` for
    ``k in range(n_lags)``, with a zero denominator replaced by 1.0.
    Uses one FFT cross-correlation plus cumulative-sum window energies, i.e.
    O(N log N) instead of an O(n_lags * len(ref)) Python loop.
    Requires ``n_lags - 1 + len(ref) <= len(sig)``.
    """
    m = len(ref)
    ref64 = np.asarray(ref, dtype=np.float64)
    sig64 = np.asarray(sig, dtype=np.float64)

    # Zero-pad to a power of two > len(sig) + m - 1 so the circular
    # correlation equals the linear one for every requested lag.
    n_fft = 1 << max(len(sig64) + m - 1, 1).bit_length()
    spectrum = np.fft.rfft(sig64, n_fft) * np.conj(np.fft.rfft(ref64, n_fft))
    dots = np.fft.irfft(spectrum, n_fft)[:n_lags]

    # Window energies from a running sum of squares. Differences of large
    # cumulative sums carry roundoff, so near-zero energies are snapped to 0.
    energy = np.concatenate(([0.0], np.cumsum(sig64 * sig64)))
    win_energy = energy[m:m + n_lags] - energy[:n_lags]
    win_energy[win_energy < 1e-9 * max(float(energy[-1]), 1.0)] = 0.0
    denom = float(np.linalg.norm(ref64)) * np.sqrt(win_energy)
    denom[denom == 0.0] = 1.0
    return dots / denom


def _estimate_first_offset_samples(ref_loop_wav, gen_wav, sr: int, spb: float, max_ms: int = 180) -> int:
    """Estimate how late (+) or early (-) the generated first downbeat is, in samples.

    Correlates the onset envelope of the last bar of ``ref_loop_wav`` with
    the first two bars of ``gen_wav`` over lags within +/- ``max_ms``.
    Envelopes are upsampled x4 before correlating, and the best lag is
    rounded back to whole samples. Best-effort: returns 0 on any failure.
    """
    up = 4
    try:
        ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
        ref_samples = _as_2d(ref.samples)
        gen_samples = _as_2d(gen_wav.samples)
        n_bar = int(round(spb * sr))
        ref_tail = ref_samples[-n_bar:] if ref_samples.shape[0] >= n_bar else ref_samples
        gen_head = gen_samples[: 2 * n_bar]
        if ref_tail.size == 0 or gen_head.size == 0:
            return 0

        # z-scored for scale invariance, then upsampled for finer lag resolution
        e_ref = _upsample_linear(_zscore(_mono_envelope(ref_tail, sr)), up)
        e_gen = _upsample_linear(_zscore(_mono_envelope(gen_head, sr)), up)

        max_lag_u = int(round((max_ms / 1000.0) * sr * up))
        seg = min(len(e_ref), len(e_gen))
        e_ref = e_ref[-seg:]
        # Zero-pad both ends so lags in [-max_lag_u, +max_lag_u] are all valid;
        # window index k corresponds to lag k - max_lag_u.
        pad = np.zeros(max_lag_u, dtype=np.float32)
        e_gen_pad = np.concatenate([pad, e_gen, pad])

        scores = _normalized_sliding_xcorr(e_ref, e_gen_pad, 2 * max_lag_u + 1)
        finite = np.isfinite(scores)
        if not finite.any():
            return 0
        # argmax returns the first maximum, i.e. the most negative best lag.
        best_lag_u = int(np.argmax(np.where(finite, scores, -np.inf))) - max_lag_u
        return int(round(best_lag_u / up))
    except Exception as e:
        print(f"⚠️ First-chunk offset estimation failed: {e}")
        return 0


class JamWorker(threading.Thread):
    """Background thread that continuously generates bar-aligned jam chunks.

    Chunks are appended to ``outbox`` and handed out in order by
    ``get_next_chunk``. Generation pauses when it runs too far ahead of the
    consumer. Knob updates and reseeds may arrive from other threads; the
    state they share with the generation loop is guarded by ``self._lock``.
    """

    def __init__(self, mrt, params: JamParams):
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params
        self.state = mrt.init_state()

        # Synchronization and placeholders must exist before the context setup
        # below, which takes the lock and records the original context.
        self._lock = threading.Lock()
        self._original_context_tokens = None

        # Incremental-stream / reseed-handoff bookkeeping, used by
        # _append_model_chunk_to_stream and _prepare_stream_for_reseed_handoff.
        self._stream = None
        self._next_emit_start = 0
        self._needs_bar_realign = False
        self._pending_drop_intro_bars = 0

        if params.combined_loop is not None:
            self._setup_context_from_combined_loop()

        self.idx = 0  # index of the most recently published chunk
        self.outbox: list[JamChunk] = []
        self._stop_event = threading.Event()

        # Delivery tracking / flow control
        self._last_delivered_index = 0
        self._max_buffer_ahead = 5

        # Timing info (time.time() timestamps)
        self.last_chunk_started_at = None
        self.last_chunk_completed_at = None

    # ------------------------------------------------------------------
    # Context construction
    # ------------------------------------------------------------------
    def _make_recent_tokens_from_wave(self, wav) -> np.ndarray:
        """Encode ``wav`` and return a bar-aligned context token window.

        Codec tokens are cast to int32 and truncated to the decoder's RVQ
        depth before being bar-aligned by ``make_bar_aligned_context``.
        """
        tokens_full = self.mrt.codec.encode(wav).astype(np.int32)
        tokens = tokens_full[:, : self.mrt.config.decoder_codec_rvq_depth]
        return make_bar_aligned_context(
            tokens,
            bpm=self.params.bpm,
            fps=float(self.mrt.codec.frame_rate),  # keep fractional fps
            ctx_frames=self.mrt.config.context_length_frames,
            beats_per_bar=self.params.beats_per_bar,
        )

    def _context_tokens_from_wave(self, wav) -> np.ndarray:
        """Build context tokens from the bar-aligned tail of ``wav`` that fits the context window."""
        codec_fps = float(self.mrt.codec.frame_rate)
        ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
        tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, ctx_seconds)
        return self._make_recent_tokens_from_wave(tail)

    def _setup_context_from_combined_loop(self):
        """Install context tokens built from ``params.combined_loop`` into ``self.state``.

        The first successful call also keeps a copy as the "original" context
        that ``reseed_splice`` anchors to. Best-effort: a failure is printed
        and the jam continues with whatever state it already had.
        """
        try:
            context_tokens = self._context_tokens_from_wave(self.params.combined_loop)
            self.state.context_tokens = context_tokens
            print("✅ JamWorker: Set up fresh context from combined loop")
            with self._lock:
                if self._original_context_tokens is None:
                    self._original_context_tokens = np.copy(context_tokens)
        except Exception as e:
            print(f"❌ Failed to setup context from combined loop: {e}")

    # ------------------------------------------------------------------
    # Control / delivery API
    # ------------------------------------------------------------------
    def stop(self):
        """Ask the generation loop (and any waiting get_next_chunk) to finish."""
        self._stop_event.set()

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Update sampling knobs; the new values apply from the next generated chunk."""
        with self._lock:
            if guidance_weight is not None:
                self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:
                self.params.temperature = float(temperature)
            if topk is not None:
                self.params.topk = int(topk)

    def get_next_chunk(self) -> JamChunk | None:
        """Return the next sequential chunk, waiting up to 30 s for it to be generated.

        Returns None on timeout or once the worker has been stopped.
        """
        target_index = self._last_delivered_index + 1
        deadline = time.monotonic() + 30.0  # max wait in seconds; immune to wall-clock jumps
        while time.monotonic() < deadline and not self._stop_event.is_set():
            with self._lock:
                for chunk in self.outbox:
                    if chunk.index == target_index:
                        # max(): mark_chunk_consumed may already have moved past us.
                        self._last_delivered_index = max(self._last_delivered_index, target_index)
                        print(f"📦 Delivered chunk {target_index}")
                        return chunk
            # Not ready yet; returns early if stop() is called.
            self._stop_event.wait(0.1)
        return None

    def mark_chunk_consumed(self, chunk_index: int):
        """Record that the frontend consumed ``chunk_index`` (never moves backwards)."""
        with self._lock:
            self._last_delivered_index = max(self._last_delivered_index, chunk_index)
        print(f"✅ Chunk {chunk_index} consumed")

    def _should_generate_next_chunk(self) -> bool:
        """Return True while ``idx <= last_delivered + _max_buffer_ahead``.

        This lets up to ``_max_buffer_ahead + 1`` undelivered chunks exist.
        """
        with self._lock:
            return self.idx <= self._last_delivered_index + self._max_buffer_ahead

    # ------------------------------------------------------------------
    # Musical timing
    # ------------------------------------------------------------------
    def _seconds_per_bar(self) -> float:
        return self.params.beats_per_bar * (60.0 / self.params.bpm)

    def _frames_per_bar(self) -> int:
        """Codec frames per musical bar, rounded to the nearest frame."""
        fps = float(self.mrt.codec.frame_rate)
        sec_per_bar = (60.0 / float(self.params.bpm)) * float(self.params.beats_per_bar)
        return int(round(fps * sec_per_bar))

    def _ctx_frames(self) -> int:
        """Number of codec frames in the model's conditioning window."""
        return int(self.mrt.config.context_length_frames)

    # ------------------------------------------------------------------
    # Audio assembly / encoding
    # ------------------------------------------------------------------
    def _snap_and_encode(self, y, seconds, target_sr, bars):
        """Resample ``y`` to ``target_sr``, snap it to ``seconds`` and base64-encode it.

        Returns ``(b64, meta)``; ``meta`` describes tempo, layout and the
        sampling knobs in effect.
        """
        cur_sr = int(self.mrt.sample_rate)
        x = resample_and_snap(_as_2d(y.samples), cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
        b64, total_samples, channels = wav_bytes_base64(x, target_sr)
        spb = self._seconds_per_bar()
        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(bars),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(target_sr),
            "channels": channels,
            "total_samples": total_samples,
            "seconds_per_bar": spb,
            "loop_duration_seconds": bars * spb,
            "guidance_weight": self.params.guidance_weight,
            "temperature": self.params.temperature,
            "topk": self.params.topk,
        }
        return b64, meta

    def _append_model_chunk_to_stream(self, wav):
        """Incrementally append a model chunk to ``self._stream`` with an equal-power crossfade.

        For the first chunk, the leading ``crossfade_length`` seconds (model
        pre-roll) are dropped. Later chunks are blended over the stream's last
        ``crossfade_length`` seconds. When the crossfade length is zero or
        either buffer is too short to overlap, the chunk is concatenated as-is.
        """
        xfade_s = float(self.mrt.config.crossfade_length)
        sr = int(self.mrt.sample_rate)
        xfade_n = int(round(xfade_s * sr))
        s = _as_2d(wav.samples)

        if self._stream is None:
            if s.shape[0] > xfade_n:
                self._stream = s[xfade_n:].astype(np.float32, copy=True)
            else:
                self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
            self._next_emit_start = 0  # pointer into _stream (model-SR samples)
            return

        # xfade_n == 0 must take this path: x[-0:] is the whole array and
        # x[:-0] is empty, so the crossfade slicing below would break.
        if xfade_n <= 0 or s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
            self._stream = np.concatenate([self._stream, s], axis=0)
            return

        tail = self._stream[-xfade_n:]
        head = s[:xfade_n]
        # Equal-power envelopes: sin^2 + cos^2 == 1 keeps perceived level steady.
        t = np.linspace(0, np.pi / 2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
        fade_in, fade_out = np.sin(t), np.cos(t)
        mixed = tail * fade_out + head * fade_in
        self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)

    # ------------------------------------------------------------------
    # Reseeding
    # ------------------------------------------------------------------
    def reseed_from_waveform(self, wav):
        """Replace the model state with a fresh one whose context is built from ``wav``."""
        new_state = self.mrt.init_state()
        new_state.context_tokens = self._context_tokens_from_wave(wav)
        with self._lock:
            self.state = new_state
            self._prepare_stream_for_reseed_handoff()

    def _bar_aligned_tail(self, tokens: np.ndarray, bars: float) -> np.ndarray:
        """Return the last ``round(bars)`` bars of ``tokens`` (all of them if shorter).

        Whole codec frames keep the slice phase-consistent with the codec grid.
        """
        frames_per_bar = self._frames_per_bar()
        want = max(frames_per_bar * int(round(bars)), 0)
        if want == 0:
            return tokens[:0]  # empty
        if tokens.shape[0] <= want:
            return tokens
        return tokens[-want:]

    def _splice_context(self, original_tokens: np.ndarray, recent_tokens: np.ndarray,
                        anchor_bars: float) -> np.ndarray:
        """Build a context window: ``floor(anchor_bars)`` bars of original tokens, then recent ones.

        The remainder of the window is filled with whole bars of
        ``recent_tokens`` when at least one fits. The result is trimmed to at
        most the context length and to a whole number of bars when possible,
        keeping ``original_tokens``' RVQ depth.
        """
        ctx_frames = self._ctx_frames()
        depth = original_tokens.shape[1]
        # Guard against a bar rounding to zero frames so the divisions below are safe.
        frames_per_bar = max(self._frames_per_bar(), 1)

        # 1) Anchor tail. floor, not round, to avoid grabbing an extra bar.
        anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))

        # 2) Fill the remainder with recent tokens, in whole bars when possible.
        remain = max(ctx_frames - anchor.shape[0], 0)
        if remain > 0:
            bars_fit = remain // frames_per_bar
            want = int(bars_fit * frames_per_bar) if bars_fit >= 1 else remain
            recent = recent_tokens[-want:] if recent_tokens.shape[0] > want else recent_tokens
        else:
            recent = recent_tokens[:0]

        if anchor.size or recent.size:
            out = np.concatenate([anchor, recent], axis=0)
        else:
            out = recent_tokens[-ctx_frames:]
        if out.shape[0] > ctx_frames:
            out = out[-ctx_frames:]

        # Force the total length to a whole number of bars.
        max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
        if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
            out = out[-max_bar_aligned:]

        if out.shape[1] != depth:
            out = out[:, :depth]
        return out

    def _realign_emit_pointer_to_bar(self, sr_model: int):
        """Advance ``_next_emit_start`` to the next bar boundary in model-sample space."""
        bar_samps = int(round(self._seconds_per_bar() * sr_model))
        if bar_samps <= 0:
            return
        phase = self._next_emit_start % bar_samps
        if phase != 0:
            self._next_emit_start += (bar_samps - phase)

    def _prepare_stream_for_reseed_handoff(self):
        """Discard the incremental stream so the next chunk starts from a clean slate.

        An earlier version kept the old stream's crossfade tail, which caused
        a phase offset after reseeding. The tail is now dropped, and the emit
        pointer is flagged for bar realignment. The reseed methods call this
        while holding ``self._lock``.
        """
        self._stream = None
        self._next_emit_start = 0
        self._needs_bar_realign = True

    def reseed_splice(self, recent_wav, anchor_bars: float):
        """Token-splice reseed.

        - original = the context captured when the jam started
        - recent   = tokens from the provided recent waveform (usually Swift-combined mix)
        - anchor_bars controls how much of the original vibe is re-injected
        """
        # Encode outside the lock so the generation loop is not blocked on it.
        recent_tokens = self._make_recent_tokens_from_wave(recent_wav)  # [T, depth]

        with self._lock:
            if self._original_context_tokens is None:
                # Fallback: without captured originals, treat the current context as original.
                self._original_context_tokens = np.copy(self.state.context_tokens)

            new_ctx = self._splice_context(self._original_context_tokens, recent_tokens, anchor_bars)

            new_state = self.mrt.init_state()
            new_state.context_tokens = new_ctx
            self.state = new_state
            self._prepare_stream_for_reseed_handoff()

            # Optional hint: ask the streamer to drop an intro crossfade's worth of audio.
            self._pending_drop_intro_bars += 1

    # ------------------------------------------------------------------
    # Generation loop
    # ------------------------------------------------------------------
    def _generate_model_chunks(self, chunk_secs: float, xfade: float, style_vec) -> list:
        """Generate model sub-chunks until they cover ``chunk_secs`` of audible audio.

        The first sub-chunk contributes its full length; each later one
        contributes its length minus ``xfade``. Stops early if the worker is
        stopped.
        """
        sr = float(self.mrt.sample_rate)
        assembled = 0.0
        chunks = []
        while assembled < chunk_secs and not self._stop_event.is_set():
            with self._lock:
                state_in = self.state
            wav, state_out = self.mrt.generate_chunk(state=state_in, style=style_vec)
            with self._lock:
                # A reseed may have installed a new state while we were generating;
                # keep it rather than clobbering it with the stale continuation.
                if self.state is state_in:
                    self.state = state_out
            chunks.append(wav)
            seconds = wav.samples.shape[0] / sr
            assembled += seconds if len(chunks) == 1 else max(0.0, seconds - xfade)
        return chunks

    def _align_first_chunk_to_grid(self, y, spb: float, chunk_secs: float):
        """Shift the first chunk so its downbeat lines up with ``combined_loop``.

        A positive offset (model late) trims the head; a negative one (model
        early) pads the head with silence. The result is re-trimmed to
        ``chunk_secs``.
        """
        offset = _estimate_first_offset_samples(
            self.params.combined_loop, y, int(self.mrt.sample_rate), spb,
            max_ms=180,  # try 160-200
        )
        if offset == 0:
            return y
        if offset > 0:
            y.samples = y.samples[offset:, :]
        else:
            pad = np.zeros((-offset, y.samples.shape[1]), dtype=y.samples.dtype)
            y.samples = np.concatenate([pad, y.samples], axis=0)
        print(f"🎯 First-chunk offset compensation: {offset/self.mrt.sample_rate:+.3f}s")
        return hard_trim_seconds(y, chunk_secs)

    def run(self):
        """Main worker loop: generate chunks continuously without running too far ahead."""
        spb = self._seconds_per_bar()
        chunk_secs = self.params.bars_per_chunk * spb
        xfade = float(self.mrt.config.crossfade_length)  # seconds

        print("🚀 JamWorker started with flow control...")

        while not self._stop_event.is_set():
            # Don't get too far ahead of the consumer.
            if not self._should_generate_next_chunk():
                print("⏸️ Buffer full, waiting for consumption...")
                self._stop_event.wait(0.5)  # wakes immediately on stop()
                continue

            # Snapshot knobs + compute index
            with self._lock:
                style_vec = self.params.style_vec
                self.mrt.guidance_weight = float(self.params.guidance_weight)
                self.mrt.temperature = float(self.params.temperature)
                self.mrt.topk = int(self.params.topk)
                next_idx = self.idx + 1

            print(f"🎹 Generating chunk {next_idx}...")
            self.last_chunk_started_at = time.time()

            chunks = self._generate_model_chunks(chunk_secs, xfade, style_vec)
            if self._stop_event.is_set():
                break

            # Stitch (utils drops the very first model pre-roll) & trim at model SR
            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
            y = hard_trim_seconds(y, chunk_secs)

            # One-time: grid-align the very first jam chunk to kill the flam.
            if next_idx == 1 and self.params.combined_loop is not None:
                y = self._align_first_chunk_to_grid(y, spb, chunk_secs)

            # Post-processing
            if next_idx == 1 and self.params.ref_loop is not None:
                y, _ = match_loudness_to_reference(
                    self.params.ref_loop, y,
                    method=self.params.loudness_mode,
                    headroom_db=self.params.headroom_db,
                )
            else:
                apply_micro_fades(y, 3)

            # Resample + bar-snap + encode
            b64, meta = self._snap_and_encode(
                y, seconds=chunk_secs, target_sr=self.params.target_sr,
                bars=self.params.bars_per_chunk,
            )
            meta["xfade_seconds"] = xfade  # hint for clients that crossfade chunk joins

            # Publish
            with self._lock:
                self.idx = next_idx
                self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
                # Bound memory: keep at most 5 already-delivered chunks around.
                if len(self.outbox) > 10:
                    cutoff = self._last_delivered_index - 5
                    self.outbox = [ch for ch in self.outbox if ch.index > cutoff]

            self.last_chunk_completed_at = time.time()
            print(f"✅ Completed chunk {next_idx}")

        print("🛑 JamWorker stopped")