Commit
Β·
1b98b73
1
Parent(s):
783cbeb
fixing continuity
Browse files- jam_worker.py +106 -72
- utils.py +4 -2
jam_worker.py
CHANGED
@@ -350,88 +350,122 @@ class JamWorker(threading.Thread):
|
|
350 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
351 |
|
352 |
def run(self):
|
353 |
-
"""
|
354 |
-
sr_model = int(self.mrt.sample_rate)
|
355 |
spb = self._seconds_per_bar()
|
356 |
-
chunk_secs =
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
#
|
361 |
-
|
362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
|
364 |
-
print("π JamWorker
|
365 |
|
366 |
while not self._stop_event.is_set():
|
367 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
with self._lock:
|
369 |
-
if self.idx > self._last_delivered_index + self._max_buffer_ahead:
|
370 |
-
time.sleep(0.25)
|
371 |
-
continue
|
372 |
style_vec = self.params.style_vec
|
373 |
-
self.mrt.guidance_weight = self.params.guidance_weight
|
374 |
-
self.mrt.temperature = self.params.temperature
|
375 |
-
self.mrt.topk = self.params.topk
|
|
|
376 |
|
377 |
-
|
378 |
self.last_chunk_started_at = time.time()
|
379 |
-
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
380 |
-
self._append_model_chunk_to_stream(wav)
|
381 |
-
if getattr(self, "_needs_bar_realign", False):
|
382 |
-
self._realign_emit_pointer_to_bar(sr_model)
|
383 |
-
self._needs_bar_realign = False
|
384 |
-
# DEBUG
|
385 |
-
bar_samps = int(round(self._seconds_per_bar() * sr_model))
|
386 |
-
if bar_samps > 0 and (self._next_emit_start % bar_samps) != 0:
|
387 |
-
print(f"β οΈ emit pointer not aligned: phase={self._next_emit_start % bar_samps}")
|
388 |
-
else:
|
389 |
-
print("β
emit pointer aligned to bar")
|
390 |
-
|
391 |
-
self.last_chunk_completed_at = time.time()
|
392 |
-
|
393 |
-
# While we have at least one full 8-bar window available, emit it
|
394 |
-
while (getattr(self, "_stream", None) is not None and
|
395 |
-
self._stream.shape[0] - self._next_emit_start >= chunk_n_model and
|
396 |
-
not self._stop_event.is_set()):
|
397 |
-
|
398 |
-
seg = self._stream[self._next_emit_start:self._next_emit_start + chunk_n_model]
|
399 |
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
self._next_emit_start += chunk_n_model
|
430 |
|
431 |
-
|
432 |
-
|
433 |
-
if keep_from > 0:
|
434 |
-
self._stream = self._stream[keep_from:]
|
435 |
-
self._next_emit_start -= keep_from
|
436 |
|
437 |
-
print("π JamWorker
|
|
|
350 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
351 |
|
352 |
def run(self):
|
353 |
+
"""Main worker loop - generate chunks continuously but don't get too far ahead"""
|
|
|
354 |
spb = self._seconds_per_bar()
|
355 |
+
chunk_secs = self.params.bars_per_chunk * spb
|
356 |
+
xfade = float(self.mrt.config.crossfade_length) # seconds
|
357 |
+
|
358 |
+
# local fallback stitcher that *keeps* the first head if utils.stitch_generated
|
359 |
+
# doesn't yet support drop_first_pre_roll
|
360 |
+
def _stitch_keep_head(chunks, sr: int, xfade_s: float):
|
361 |
+
from magenta_rt import audio as au
|
362 |
+
import numpy as _np
|
363 |
+
if not chunks:
|
364 |
+
raise ValueError("no chunks to stitch")
|
365 |
+
xfade_n = int(round(max(0.0, xfade_s) * sr))
|
366 |
+
# Fast-path: no crossfade
|
367 |
+
if xfade_n <= 0:
|
368 |
+
out = _np.concatenate([c.samples for c in chunks], axis=0)
|
369 |
+
return au.Waveform(out, sr)
|
370 |
+
# build equal-power curves
|
371 |
+
t = _np.linspace(0, _np.pi / 2, xfade_n, endpoint=False, dtype=_np.float32)
|
372 |
+
eq_in, eq_out = _np.sin(t)[:, None], _np.cos(t)[:, None]
|
373 |
+
|
374 |
+
first = chunks[0].samples
|
375 |
+
if first.shape[0] < xfade_n:
|
376 |
+
raise ValueError("chunk shorter than crossfade prefix")
|
377 |
+
out = first.copy() # π keep the head for live seam
|
378 |
+
|
379 |
+
for i in range(1, len(chunks)):
|
380 |
+
cur = chunks[i].samples
|
381 |
+
if cur.shape[0] < xfade_n:
|
382 |
+
# too short to crossfade; just butt-join
|
383 |
+
out = _np.concatenate([out, cur], axis=0)
|
384 |
+
continue
|
385 |
+
head, tail = cur[:xfade_n], cur[xfade_n:]
|
386 |
+
mixed = out[-xfade_n:] * eq_out + head * eq_in
|
387 |
+
out = _np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
|
388 |
+
return au.Waveform(out, sr)
|
389 |
|
390 |
+
print("π JamWorker started with flow control...")
|
391 |
|
392 |
while not self._stop_event.is_set():
|
393 |
+
# Donβt get too far ahead of the consumer
|
394 |
+
if not self._should_generate_next_chunk():
|
395 |
+
# We're ahead enough, wait a bit for frontend to catch up
|
396 |
+
# (kept short so stop() stays responsive)
|
397 |
+
time.sleep(0.5)
|
398 |
+
continue
|
399 |
+
|
400 |
+
# Snapshot knobs + compute index atomically
|
401 |
with self._lock:
|
|
|
|
|
|
|
402 |
style_vec = self.params.style_vec
|
403 |
+
self.mrt.guidance_weight = float(self.params.guidance_weight)
|
404 |
+
self.mrt.temperature = float(self.params.temperature)
|
405 |
+
self.mrt.topk = int(self.params.topk)
|
406 |
+
next_idx = self.idx + 1
|
407 |
|
408 |
+
print(f"πΉ Generating chunk {next_idx}...")
|
409 |
self.last_chunk_started_at = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
|
411 |
+
# ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
|
412 |
+
# Count the first chunk at full length L, and each subsequent at (L - xfade)
|
413 |
+
assembled = 0.0
|
414 |
+
chunks = []
|
415 |
+
|
416 |
+
while assembled < chunk_secs and not self._stop_event.is_set():
|
417 |
+
# generate_chunk returns (au.Waveform, new_state)
|
418 |
+
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
419 |
+
chunks.append(wav)
|
420 |
+
L = wav.samples.shape[0] / float(self.mrt.sample_rate)
|
421 |
+
assembled += L if len(chunks) == 1 else max(0.0, L - xfade)
|
422 |
+
|
423 |
+
if self._stop_event.is_set():
|
424 |
+
break
|
425 |
+
|
426 |
+
# ---- Stitch and trim at model SR (keep first head for seamless handoff) ----
|
427 |
+
try:
|
428 |
+
# Preferred path if you've added the new param in utils.stitch_generated
|
429 |
+
y = stitch_generated(chunks, self.mrt.sample_rate, xfade, drop_first_pre_roll=False).as_stereo()
|
430 |
+
except TypeError:
|
431 |
+
# Backward-compatible: local stitcher that keeps the head
|
432 |
+
y = _stitch_keep_head(chunks, int(self.mrt.sample_rate), xfade).as_stereo()
|
433 |
+
|
434 |
+
# Hard trim to the exact musical duration (still at model SR)
|
435 |
+
y = hard_trim_seconds(y, chunk_secs)
|
436 |
+
|
437 |
+
# ---- Post-processing ----
|
438 |
+
if next_idx == 1 and self.params.ref_loop is not None:
|
439 |
+
# match loudness to the provided reference on the very first audible chunk
|
440 |
+
y, _ = match_loudness_to_reference(
|
441 |
+
self.params.ref_loop, y,
|
442 |
+
method=self.params.loudness_mode,
|
443 |
+
headroom_db=self.params.headroom_db
|
444 |
)
|
445 |
+
else:
|
446 |
+
# light micro-fades to guard against clicks
|
447 |
+
apply_micro_fades(y, 3)
|
448 |
+
|
449 |
+
# ---- Resample + bar-snap + encode ----
|
450 |
+
b64, meta = self._snap_and_encode(
|
451 |
+
y,
|
452 |
+
seconds=chunk_secs,
|
453 |
+
target_sr=self.params.target_sr,
|
454 |
+
bars=self.params.bars_per_chunk
|
455 |
+
)
|
456 |
+
# small hint for the client if you want UI butter between chunks
|
457 |
+
meta["xfade_seconds"] = xfade
|
458 |
|
459 |
+
# ---- Publish the completed chunk ----
|
460 |
+
with self._lock:
|
461 |
+
self.idx = next_idx
|
462 |
+
self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
|
463 |
+
# Keep outbox bounded (trim far-behind entries)
|
464 |
+
if len(self.outbox) > 10:
|
465 |
+
cutoff = self._last_delivered_index - 5
|
466 |
+
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
|
|
467 |
|
468 |
+
self.last_chunk_completed_at = time.time()
|
469 |
+
print(f"β
Completed chunk {next_idx}")
|
|
|
|
|
|
|
470 |
|
471 |
+
print("π JamWorker stopped")
|
utils.py
CHANGED
@@ -69,7 +69,7 @@ def match_loudness_to_reference(
|
|
69 |
|
70 |
|
71 |
# ---------- Stitch / fades / trims ----------
|
72 |
-
def stitch_generated(chunks, sr: int, xfade_s: float
|
73 |
if not chunks:
|
74 |
raise ValueError("no chunks")
|
75 |
xfade_n = int(round(xfade_s * sr))
|
@@ -82,7 +82,9 @@ def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform:
|
|
82 |
first = chunks[0].samples
|
83 |
if first.shape[0] < xfade_n:
|
84 |
raise ValueError("chunk shorter than crossfade prefix")
|
85 |
-
|
|
|
|
|
86 |
|
87 |
for i in range(1, len(chunks)):
|
88 |
cur = chunks[i].samples
|
|
|
69 |
|
70 |
|
71 |
# ---------- Stitch / fades / trims ----------
|
72 |
+
def stitch_generated(chunks, sr: int, xfade_s: float, drop_first_pre_roll: bool = True):
|
73 |
if not chunks:
|
74 |
raise ValueError("no chunks")
|
75 |
xfade_n = int(round(xfade_s * sr))
|
|
|
82 |
first = chunks[0].samples
|
83 |
if first.shape[0] < xfade_n:
|
84 |
raise ValueError("chunk shorter than crossfade prefix")
|
85 |
+
|
86 |
+
# π§ key change:
|
87 |
+
out = first[xfade_n:].copy() if drop_first_pre_roll else first.copy()
|
88 |
|
89 |
for i in range(1, len(chunks)):
|
90 |
cur = chunks[i].samples
|