Spaces:
Running
on
L40S
Running
on
L40S
Commit
·
dec57f5
1
Parent(s):
381da60
streaming resampler added
Browse files- jam_worker.py +152 -142
jam_worker.py
CHANGED
@@ -410,186 +410,196 @@ class JamWorker(threading.Thread):
|
|
410 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
411 |
|
412 |
def run(self):
|
413 |
-
"""Main worker loop —
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
emit_phase = float(getattr(self, "_emit_phase", 0.0))
|
434 |
-
|
435 |
-
|
436 |
-
# How much we want available beyond 'start' for this emit.
|
437 |
-
want = step_int
|
438 |
if first_chunk_extra:
|
439 |
-
|
440 |
-
# Use ceil to be conservative so we don't under-request.
|
441 |
-
want += int(ceil(2.0 * spb * sr))
|
442 |
-
|
443 |
return max(0, want - have)
|
444 |
|
445 |
-
|
446 |
-
|
447 |
-
x = np.abs(x).astype(np.float32)
|
448 |
-
w = max(1, int(round(win_ms * 1e-3 * sr)))
|
449 |
-
if w > 1:
|
450 |
-
kern = np.ones(w, dtype=np.float32) / float(w)
|
451 |
-
x = np.convolve(x, kern, mode="same")
|
452 |
-
d = np.diff(x, prepend=x[:1])
|
453 |
-
d[d < 0] = 0.0
|
454 |
-
return d
|
455 |
-
|
456 |
-
def _estimate_first_offset_samples(ref_loop_wav, gen_head_wav, sr: int, spb: float) -> int:
    """Tempo-aware first-downbeat offset (positive => model late).

    Cross-correlates the onset envelope of the last bar of ``ref_loop_wav``
    against the first two bars of ``gen_head_wav`` and returns the lag, in
    samples at ``sr``, with the highest normalized correlation.

    Args:
        ref_loop_wav: waveform object exposing ``sample_rate``, ``samples``
            and ``resample(sr)``; resampled to ``sr`` if its rate differs.
        gen_head_wav: waveform object exposing ``samples``.
            NOTE(review): assumed to already be at ``sr`` — it is never
            resampled here; confirm against the caller.
        sr: sample rate used for the bar length and the returned lag.
        spb: seconds per bar.

    Returns:
        The best lag in samples, or 0 when either excerpt is empty or when
        any step raises (the estimate is deliberately best-effort).
    """
    try:
        # Search window: a quarter bar, clamped to the range 160..450 ms.
        max_ms = int(max(160.0, min(0.25 * spb * 1000.0, 450.0)))
        ref = ref_loop_wav if ref_loop_wav.sample_rate == sr else ref_loop_wav.resample(sr)
        n_bar = int(round(spb * sr))
        ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
        gen_head = gen_head_wav.samples[: int(2 * n_bar), :]
        if ref_tail.size == 0 or gen_head.size == 0:
            return 0

        # envelopes + z-score
        def _z(a):
            m, s = float(a.mean()), float(a.std() or 1.0)  # std of 0 -> 1.0 avoids divide-by-zero
            return (a - m) / s

        e_ref = _z(_mono_env(ref_tail, sr)).astype(np.float32)
        e_gen = _z(_mono_env(gen_head, sr)).astype(np.float32)

        # upsample x4 for finer lag
        def _upsample(a, r=4):
            n = len(a)
            grid = np.arange(n, dtype=np.float32)
            fine = np.linspace(0, n - 1, num=n * r, dtype=np.float32)
            return np.interp(fine, grid, a).astype(np.float32)

        up = 4
        e_ref_u, e_gen_u = _upsample(e_ref, up), _upsample(e_gen, up)

        max_lag_u = int(round((max_ms / 1000.0) * sr * up))
        seg = min(len(e_ref_u), len(e_gen_u))
        e_ref_u = e_ref_u[-seg:]
        # Zero-pad both sides so every lag in [-max_lag_u, +max_lag_u] yields a full-length window.
        pad = np.zeros(max_lag_u, dtype=np.float32)
        e_gen_u_pad = np.concatenate([pad, e_gen_u, pad])

        # The reference norm does not depend on the lag: compute it once
        # instead of once per loop iteration.
        ref_norm = np.linalg.norm(e_ref_u)

        best_lag_u, best_score = 0, -1e9
        for lag_u in range(-max_lag_u, max_lag_u + 1):
            start = max_lag_u + lag_u
            b = e_gen_u_pad[start : start + seg]
            denom = (ref_norm * np.linalg.norm(b)) or 1.0
            score = float(np.dot(e_ref_u, b) / denom)
            if score > best_score:
                best_score, best_lag_u = score, lag_u
        # Convert the lag from the upsampled grid back to samples at `sr`.
        return int(round(best_lag_u / up))
    except Exception:
        # Best-effort: any failure means "no offset correction".
        return 0
|
498 |
-
|
499 |
-
print("🚀 JamWorker started (bar-aligned streaming)…")
|
500 |
|
|
|
501 |
while not self._stop_event.is_set():
|
|
|
502 |
if not self._should_generate_next_chunk():
|
503 |
-
time.sleep(0.
|
504 |
continue
|
505 |
|
506 |
-
# 1)
|
507 |
need = _need(first_chunk_extra=(self.idx == 0))
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
|
|
514 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
533 |
self._needs_bar_realign = False
|
534 |
self._reseed_ref_loop = None
|
535 |
|
536 |
-
# 3)
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
544 |
continue
|
545 |
|
546 |
-
|
547 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
548 |
|
549 |
-
|
|
|
550 |
|
551 |
-
#
|
552 |
if self.idx == 0 and self.params.ref_loop is not None:
|
553 |
y, _ = match_loudness_to_reference(
|
554 |
self.params.ref_loop, y,
|
555 |
method=self.params.loudness_mode,
|
556 |
headroom_db=self.params.headroom_db
|
557 |
)
|
558 |
-
|
559 |
-
apply_micro_fades(y, 3)
|
560 |
|
561 |
-
#
|
562 |
-
b64,
|
563 |
-
y
|
564 |
)
|
565 |
-
meta["xfade_seconds"] = xfade
|
566 |
|
567 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
568 |
with self._lock:
|
569 |
self.idx += 1
|
570 |
self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
|
|
|
571 |
if len(self.outbox) > 10:
|
572 |
cutoff = self._last_delivered_index - 5
|
573 |
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
574 |
|
575 |
-
#
|
576 |
-
if self
|
577 |
pkg = self._pending_reseed
|
578 |
self._pending_reseed = None
|
579 |
-
|
580 |
-
|
581 |
-
new_state.context_tokens = pkg["ctx"] # exact (ctx_frames, depth)
|
582 |
-
self.state = new_state
|
583 |
-
|
584 |
-
# start a fresh stream and schedule one-time alignment
|
585 |
-
self._stream = None
|
586 |
-
self._next_emit_start = 0
|
587 |
-
self._reseed_ref_loop = pkg.get("ref") or self.params.combined_loop
|
588 |
self._needs_bar_realign = True
|
|
|
589 |
|
590 |
-
|
591 |
|
592 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
|
594 |
print("🛑 JamWorker stopped")
|
595 |
-
|
|
|
410 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
411 |
|
412 |
def run(self):
    """Main worker loop — continuous gen at model SR, stream resampled chunks butt-joined at target SR.

    Repeats until ``self._stop_event`` is set:

    0. Back off while ``self._should_generate_next_chunk()`` is false.
    1. Generate model audio until enough *stable* input samples exist (the
       last ``xfade_n_in`` samples of ``self._stream`` are held back), and
       feed every newly stable sample to the resampler / output buffer.
    2. Run the one-shot bar realign when ``self._needs_bar_realign`` is set.
    3. Slice one client chunk out of ``self._stream_out``, loudness-match
       the first chunk to ``params.ref_loop`` when one is given, encode it
       and append a ``JamChunk`` to ``self.outbox`` under ``self._lock``.
    4. Consume ``self._pending_reseed`` by scheduling another realign.

    Chunk lengths are tracked with fractional phase accumulators in both the
    input and the output domain, so a non-integer samples-per-chunk value is
    spread over successive chunks instead of being truncated every time.

    When the loop exits, the resampler (if any) is flushed into
    ``self._stream_out``; flush errors are printed, not raised.
    """
    import numpy as _np
    from math import floor, ceil

    spb = self._seconds_per_bar()                           # seconds per bar
    chunk_secs = float(self.params.bars_per_chunk) * spb    # seconds per client chunk
    xfade_s = float(self.mrt.config.crossfade_length)       # seconds of model equal-power xfade
    sr_in = int(self.mrt.sample_rate)                       # model/native SR
    sr_out = int(self.params.target_sr or sr_in)            # desired output SR (e.g., 44100)
    ch = 2                                                  # enforce stereo out

    # --- Fractional emit stepper (input domain) ---
    chunk_step_in_f = chunk_secs * sr_in
    self._emit_phase = float(getattr(self, "_emit_phase", 0.0))

    # --- Output-domain emit state (the step size itself is computed below, once sr_out is final) ---
    self._emit_phase_out = float(getattr(self, "_emit_phase_out", 0.0))
    self._next_emit_start_out = int(getattr(self, "_next_emit_start_out", 0))

    # --- Streaming resampler state (input -> output); hold back xfade tail so overlapped region is final ---
    xfade_n_in = int(round(xfade_s * sr_in))
    self._resampler = None
    self._stream_out = None
    # how many INPUT samples we fed to the resampler
    self._resample_cursor_in = int(getattr(self, "_resample_cursor_in", 0))

    if sr_out != sr_in:
        try:
            from utils import StreamingResampler
            self._resampler = StreamingResampler(in_sr=sr_in, out_sr=sr_out, channels=ch, quality="VHQ")
            self._stream_out = _np.zeros((0, ch), dtype=_np.float32)
        except Exception as e:
            print(f"⚠️ Could not init StreamingResampler ({e}); falling back to alias-mode (sr_out==sr_in).")
            sr_out = sr_in
            self.params.target_sr = sr_out
            self._resampler = None
            self._stream_out = _np.zeros((0, ch), dtype=_np.float32)
            self._resample_cursor_in = 0
    else:
        self._stream_out = _np.zeros((0, ch), dtype=_np.float32)
        self._resample_cursor_in = 0

    # FIX: derive the output step only now. sr_out may have just fallen back
    # to sr_in; a step computed from the originally requested rate would cut
    # chunks of the wrong length out of a stream that is really at sr_in.
    chunk_step_out_f = chunk_secs * sr_out

    # --- helper: how many more INPUT samples (stable) we need to be able to emit next client chunk ---
    def _need(first_chunk_extra: bool = False) -> int:
        start = int(getattr(self, "_next_emit_start", 0))
        total_in = 0 if getattr(self, "_stream", None) is None else int(self._stream.shape[0])
        total_in_stable = max(0, total_in - xfade_n_in)  # hold back xfade tail (overlap will be replaced)
        have = max(0, total_in_stable - start)
        emit_phase = float(getattr(self, "_emit_phase", 0.0))
        step_int_in = int(floor(chunk_step_in_f + emit_phase))
        want = step_int_in
        if first_chunk_extra:
            # Ask for two extra bars before the very first emit.
            # Presumably headroom for the bar realign in step 2 — confirm.
            want += int(ceil(2.0 * spb * sr_in))
        return max(0, want - have)

    print(f"▶️ JamWorker starting: bpm={self.params.bpm}, bars/chunk={self.params.bars_per_chunk}, "
          f"sr_in={sr_in}, sr_out={sr_out}, xfade_s={xfade_s:.3f}")

    # FIX: set by the step-4 guard when the input domain is satisfied but the
    # output buffer is still short. _need() only measures the input domain, so
    # without this flag the loop would sleep/continue forever in that state.
    force_generate = False

    # --- main loop ---
    while not self._stop_event.is_set():
        # 0) Backpressure: don't run too far ahead
        if not self._should_generate_next_chunk():
            time.sleep(0.01)
            continue

        # 1) Ensure enough model audio exists (INPUT domain)
        need = _need(first_chunk_extra=(self.idx == 0))
        if need > 0 or force_generate:
            force_generate = False

            # Generate one model chunk
            style_vec = self.params.style_vec
            self.mrt.guidance_weight = float(self.params.guidance_weight)
            self.mrt.temperature = float(self.params.temperature)
            self.mrt.topk = int(self.params.topk)

            wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)

            # Append (equal-power crossfade into persistent input-domain stream)
            self._append_model_chunk_to_stream(wav)

            # Feed *stable* portion into the resampler/output buffer
            if getattr(self, "_stream", None) is not None and self._stream.shape[0] > 0:
                stable_end_in = max(0, int(self._stream.shape[0]) - xfade_n_in)
                if stable_end_in > self._resample_cursor_in:
                    x_in = self._stream[self._resample_cursor_in:stable_end_in]
                    if self._resampler is not None:
                        y_out = self._resampler.process(x_in.astype(_np.float32, copy=False), final=False)
                        if y_out.size:
                            self._stream_out = y_out if self._stream_out.size == 0 else _np.vstack([self._stream_out, y_out])
                    else:
                        # pass-through (sr_out == sr_in)
                        self._stream_out = x_in if self._stream_out.size == 0 else _np.vstack([self._stream_out, x_in])
                    self._resample_cursor_in = stable_end_in
            # loop back to either generate more or try emitting
            continue

        # 2) Optional, one-shot bar realign (occurs on first slice or reseed)
        if getattr(self, "_needs_bar_realign", False):
            # NOTE(review): the helper is given only sr_in, and the output-domain
            # pointer/phase are not re-synchronised here. If it moves
            # _next_emit_start, the slice taken from _stream_out in step 5 will
            # not follow the realignment — confirm against the helper.
            self._realign_emit_pointer_to_bar(sr_in)
            self._emit_phase = 0.0  # restart fractional phase at clean bar
            self._needs_bar_realign = False
            self._reseed_ref_loop = None

        # 3) Compute next emit window in BOTH domains
        start_in = int(getattr(self, "_next_emit_start", 0))
        step_total_in = chunk_step_in_f + self._emit_phase
        step_int_in = int(floor(step_total_in))
        new_phase_in = float(step_total_in - step_int_in)  # fractional remainder carried to the next chunk
        end_in = start_in + step_int_in

        start_out = int(self._next_emit_start_out)
        step_total_out = chunk_step_out_f + self._emit_phase_out
        step_int_out = int(floor(step_total_out))
        new_phase_out = float(step_total_out - step_int_out)
        end_out = start_out + step_int_out

        # 4) Guards — do we actually have enough ready in both domains?
        total_in_stable = 0
        if getattr(self, "_stream", None) is not None:
            total_in_stable = max(0, int(self._stream.shape[0]) - xfade_n_in)
        total_out_ready = 0 if self._stream_out is None else int(self._stream_out.shape[0])

        if end_in > total_in_stable or end_out > total_out_ready:
            if end_in <= total_in_stable:
                # Input suffices but the output buffer is short (e.g. the
                # resampler is still withholding samples). _need() would keep
                # returning 0, so ask for one more model chunk explicitly.
                force_generate = True
            time.sleep(0.005)
            continue

        # 5) Slice OUTPUT-domain audio to send
        slice_out = self._stream_out[start_out:end_out]

        # Advance pointers + phases atomically
        self._next_emit_start = end_in
        self._emit_phase = new_phase_in
        self._next_emit_start_out = end_out
        self._emit_phase_out = new_phase_out

        # 6) Post and encode
        y = au.Waveform(slice_out.astype(_np.float32, copy=False), sr_out).as_stereo()

        # Loudness: only on first chunk, match to ref if provided
        if self.idx == 0 and self.params.ref_loop is not None:
            y, _ = match_loudness_to_reference(
                self.params.ref_loop, y,
                method=self.params.loudness_mode,
                headroom_db=self.params.headroom_db
            )
        # (No per-slice micro fades; stream continuity handles joins)

        # Encode WAV (already sr_out and exact length by construction)
        b64, total_samples, channels = wav_bytes_base64(
            y.samples if y.samples.ndim == 2 else y.samples[:, None], sr_out
        )

        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(self.params.bars_per_chunk),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(sr_out),
            "channels": int(channels),
            "total_samples": int(total_samples),
            "seconds_per_bar": float(spb),
            "loop_duration_seconds": float(self.params.bars_per_chunk) * float(spb),
            "guidance_weight": float(self.params.guidance_weight),
            "temperature": float(self.params.temperature),
            "topk": int(self.params.topk),
            "xfade_seconds": float(xfade_s),
        }

        with self._lock:
            self.idx += 1
            self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
            # prune outbox
            if len(self.outbox) > 10:
                cutoff = self._last_delivered_index - 5
                self.outbox = [ch for ch in self.outbox if ch.index > cutoff]

        # Apply any pending reseed *between* chunks
        if getattr(self, "_pending_reseed", None) is not None:
            pkg = self._pending_reseed
            self._pending_reseed = None
            # A reseed handler has already swapped state.context_tokens upstream.
            # Just request a one-shot bar realign against the new ref loop if present.
            self._needs_bar_realign = True
            self._reseed_ref_loop = pkg.get("ref") if isinstance(pkg, dict) else None

        time.sleep(0.001)

    # --- graceful stop: flush resampler tail so last bits become available if client requests them ---
    try:
        if self._resampler is not None:
            tail = self._resampler.flush()
            if tail.size:
                self._stream_out = tail if self._stream_out.size == 0 else _np.vstack([self._stream_out, tail])
    except Exception as e:
        print(f"⚠️ Resampler flush error: {e}")

    print("🛑 JamWorker stopped")
|
|