Commit
·
60a96af
1
Parent(s):
a8de318
gradual style changes
Browse files
- app.py +17 -4
- jam_worker.py +32 -15
app.py
CHANGED
@@ -1602,7 +1602,9 @@ async def ws_jam(websocket: WebSocket):
|
|
1602 |
# Stash rt session fields
|
1603 |
websocket._mrt = mrt
|
1604 |
websocket._state = state
|
1605 |
-
websocket.
|
|
|
|
|
1606 |
|
1607 |
websocket._rt_mean = mean_w
|
1608 |
websocket._rt_centroid_weights = cw
|
@@ -1628,7 +1630,15 @@ async def ws_jam(websocket: WebSocket):
|
|
1628 |
mrt.temperature = websocket._rt_temp
|
1629 |
mrt.topk = websocket._rt_topk
|
1630 |
|
1631 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1632 |
websocket._state = new_state
|
1633 |
|
1634 |
x = wav.samples.astype(np.float32, copy=False)
|
@@ -1726,8 +1736,7 @@ async def ws_jam(websocket: WebSocket):
|
|
1726 |
text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
|
1727 |
|
1728 |
_ensure_assets_loaded()
|
1729 |
-
|
1730 |
-
websocket._style = build_style_vector(
|
1731 |
websocket._mrt,
|
1732 |
text_styles=text_list,
|
1733 |
text_weights=text_w,
|
@@ -1736,6 +1745,10 @@ async def ws_jam(websocket: WebSocket):
|
|
1736 |
mean_weight=float(websocket._rt_mean),
|
1737 |
centroid_weights=websocket._rt_centroid_weights,
|
1738 |
)
|
|
|
|
|
|
|
|
|
1739 |
await send_json({"type":"status","updated":"rt-knobs+style"})
|
1740 |
|
1741 |
elif mtype == "consume" and mode == "bar":
|
|
|
1602 |
# Stash rt session fields
|
1603 |
websocket._mrt = mrt
|
1604 |
websocket._state = state
|
1605 |
+
websocket._style_cur = style_vec
|
1606 |
+
websocket._style_tgt = style_vec
|
1607 |
+
websocket._style_ramp_s = float(params.get("style_ramp_seconds", 0.0))
|
1608 |
|
1609 |
websocket._rt_mean = mean_w
|
1610 |
websocket._rt_centroid_weights = cw
|
|
|
1630 |
mrt.temperature = websocket._rt_temp
|
1631 |
mrt.topk = websocket._rt_topk
|
1632 |
|
1633 |
+
# ramp style
|
1634 |
+
ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
|
1635 |
+
if ramp <= 0.0:
|
1636 |
+
websocket._style_cur = websocket._style_tgt
|
1637 |
+
else:
|
1638 |
+
step = min(1.0, chunk_secs / ramp)
|
1639 |
+
websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)
|
1640 |
+
|
1641 |
+
wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style_cur)
|
1642 |
websocket._state = new_state
|
1643 |
|
1644 |
x = wav.samples.astype(np.float32, copy=False)
|
|
|
1736 |
text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
|
1737 |
|
1738 |
_ensure_assets_loaded()
|
1739 |
+
websocket._style_tgt = build_style_vector(
|
|
|
1740 |
websocket._mrt,
|
1741 |
text_styles=text_list,
|
1742 |
text_weights=text_w,
|
|
|
1745 |
mean_weight=float(websocket._rt_mean),
|
1746 |
centroid_weights=websocket._rt_centroid_weights,
|
1747 |
)
|
1748 |
+
# optionally allow live changes to ramp:
|
1749 |
+
if "style_ramp_seconds" in msg:
|
1750 |
+
try: websocket._style_ramp_s = float(msg["style_ramp_seconds"])
|
1751 |
+
except: pass
|
1752 |
await send_json({"type":"status","updated":"rt-knobs+style"})
|
1753 |
|
1754 |
elif mtype == "consume" and mode == "bar":
|
jam_worker.py
CHANGED
@@ -35,6 +35,7 @@ class JamParams:
|
|
35 |
guidance_weight: float = 1.1
|
36 |
temperature: float = 1.1
|
37 |
topk: int = 40
|
|
|
38 |
|
39 |
|
40 |
@dataclass
|
@@ -90,7 +91,11 @@ class JamWorker(threading.Thread):
|
|
90 |
self.mrt.topk = int(self.params.topk)
|
91 |
|
92 |
# style vector (already normalized upstream)
|
93 |
-
self._style_vec = self.params.style_vec
|
|
|
|
|
|
|
|
|
94 |
|
95 |
# codec/setup
|
96 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
@@ -124,7 +129,7 @@ class JamWorker(threading.Thread):
|
|
124 |
|
125 |
# control flags
|
126 |
self._stop_event = threading.Event()
|
127 |
-
self._max_buffer_ahead =
|
128 |
|
129 |
# reseed queues (install at next bar boundary after emission)
|
130 |
self._pending_reseed: Optional[dict] = None # legacy full reset path (kept for fallback)
|
@@ -136,6 +141,17 @@ class JamWorker(threading.Thread):
|
|
136 |
|
137 |
# ---------- lifecycle ----------
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
def stop(self):
|
140 |
self._stop_event.set()
|
141 |
|
@@ -400,17 +416,6 @@ class JamWorker(threading.Thread):
|
|
400 |
# also remember this as new "original"
|
401 |
self._original_context_tokens = np.copy(context_tokens)
|
402 |
|
403 |
-
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
404 |
-
"""Queue a splice reseed to be applied right after the next emitted loop.
|
405 |
-
For now, we simply replace the context by recent wave tail; anchor is accepted
|
406 |
-
for API compatibility and future crossfade/token-splice logic."""
|
407 |
-
recent_wav = recent_wav.as_stereo().resample(self._model_sr)
|
408 |
-
tail = take_bar_aligned_tail(recent_wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
|
409 |
-
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
410 |
-
depth = int(self.mrt.config.decoder_codec_rvq_depth)
|
411 |
-
new_ctx = tokens_full[:, :depth]
|
412 |
-
self._pending_reseed = {"ctx": new_ctx}
|
413 |
-
|
414 |
# ---------- core streaming helpers ----------
|
415 |
|
416 |
def _append_model_chunk_and_spool(self, wav: au.Waveform):
|
@@ -538,8 +543,20 @@ class JamWorker(threading.Thread):
|
|
538 |
# generate next model chunk
|
539 |
# snapshot current style vector under lock for this step
|
540 |
with self._lock:
|
541 |
-
|
542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
543 |
# append and spool
|
544 |
self._append_model_chunk_and_spool(wav)
|
545 |
# try emitting zero or more chunks if available
|
|
|
35 |
guidance_weight: float = 1.1
|
36 |
temperature: float = 1.1
|
37 |
topk: int = 40
|
38 |
+
style_ramp_seconds: float = 0.0 # 0 => instant (current behavior), try 6.0–10.0 for gentle glides
|
39 |
|
40 |
|
41 |
@dataclass
|
|
|
91 |
self.mrt.topk = int(self.params.topk)
|
92 |
|
93 |
# style vector (already normalized upstream)
|
94 |
+
self._style_vec = (None if self.params.style_vec is None
|
95 |
+
else np.array(self.params.style_vec, dtype=np.float32, copy=True))
|
96 |
+
self._chunk_secs = (
|
97 |
+
self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples
|
98 |
+
) / float(self._model_sr) # ≈ 2.0 s by default
|
99 |
|
100 |
# codec/setup
|
101 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
|
|
129 |
|
130 |
# control flags
|
131 |
self._stop_event = threading.Event()
|
132 |
+
self._max_buffer_ahead = 1
|
133 |
|
134 |
# reseed queues (install at next bar boundary after emission)
|
135 |
self._pending_reseed: Optional[dict] = None # legacy full reset path (kept for fallback)
|
|
|
141 |
|
142 |
# ---------- lifecycle ----------
|
143 |
|
144 |
+
def set_buffer_seconds(self, seconds: float):
|
145 |
+
"""Clamp how far ahead we allow, in *seconds* of audio."""
|
146 |
+
chunk_secs = float(self.params.bars_per_chunk) * self._bar_clock.seconds_per_bar()
|
147 |
+
max_chunks = max(0, int(round(seconds / max(chunk_secs, 1e-6))))
|
148 |
+
with self._cv:
|
149 |
+
self._max_buffer_ahead = max_chunks
|
150 |
+
|
151 |
+
def set_buffer_chunks(self, k: int):
|
152 |
+
with self._cv:
|
153 |
+
self._max_buffer_ahead = max(0, int(k))
|
154 |
+
|
155 |
def stop(self):
|
156 |
self._stop_event.set()
|
157 |
|
|
|
416 |
# also remember this as new "original"
|
417 |
self._original_context_tokens = np.copy(context_tokens)
|
418 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
# ---------- core streaming helpers ----------
|
420 |
|
421 |
def _append_model_chunk_and_spool(self, wav: au.Waveform):
|
|
|
543 |
# generate next model chunk
|
544 |
# snapshot current style vector under lock for this step
|
545 |
with self._lock:
|
546 |
+
target = self.params.style_vec
|
547 |
+
if target is None:
|
548 |
+
style_to_use = None
|
549 |
+
else:
|
550 |
+
if self._style_vec is None: # first use: start exactly at initial style (no glide)
|
551 |
+
self._style_vec = np.array(target, dtype=np.float32, copy=True)
|
552 |
+
else:
|
553 |
+
ramp = float(self.params.style_ramp_seconds or 0.0)
|
554 |
+
step = 1.0 if ramp <= 0.0 else min(1.0, self._chunk_secs / ramp)
|
555 |
+
# linear ramp in embedding space
|
556 |
+
self._style_vec += step * (target.astype(np.float32, copy=False) - self._style_vec)
|
557 |
+
style_to_use = self._style_vec
|
558 |
+
|
559 |
+
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_to_use)
|
560 |
# append and spool
|
561 |
self._append_model_chunk_and_spool(wav)
|
562 |
# try emitting zero or more chunks if available
|