thecollabagepatch commited on
Commit
60a96af
·
1 Parent(s): a8de318

gradual style changes

Browse files
Files changed (2) hide show
  1. app.py +17 -4
  2. jam_worker.py +32 -15
app.py CHANGED
@@ -1602,7 +1602,9 @@ async def ws_jam(websocket: WebSocket):
1602
  # Stash rt session fields
1603
  websocket._mrt = mrt
1604
  websocket._state = state
1605
- websocket._style = style_vec
 
 
1606
 
1607
  websocket._rt_mean = mean_w
1608
  websocket._rt_centroid_weights = cw
@@ -1628,7 +1630,15 @@ async def ws_jam(websocket: WebSocket):
1628
  mrt.temperature = websocket._rt_temp
1629
  mrt.topk = websocket._rt_topk
1630
 
1631
- wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style)
 
 
 
 
 
 
 
 
1632
  websocket._state = new_state
1633
 
1634
  x = wav.samples.astype(np.float32, copy=False)
@@ -1726,8 +1736,7 @@ async def ws_jam(websocket: WebSocket):
1726
  text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1727
 
1728
  _ensure_assets_loaded()
1729
- # build final style vec (no loop embedding in rt-mode)
1730
- websocket._style = build_style_vector(
1731
  websocket._mrt,
1732
  text_styles=text_list,
1733
  text_weights=text_w,
@@ -1736,6 +1745,10 @@ async def ws_jam(websocket: WebSocket):
1736
  mean_weight=float(websocket._rt_mean),
1737
  centroid_weights=websocket._rt_centroid_weights,
1738
  )
 
 
 
 
1739
  await send_json({"type":"status","updated":"rt-knobs+style"})
1740
 
1741
  elif mtype == "consume" and mode == "bar":
 
1602
  # Stash rt session fields
1603
  websocket._mrt = mrt
1604
  websocket._state = state
1605
+ websocket._style_cur = style_vec
1606
+ websocket._style_tgt = style_vec
1607
+ websocket._style_ramp_s = float(params.get("style_ramp_seconds", 0.0))
1608
 
1609
  websocket._rt_mean = mean_w
1610
  websocket._rt_centroid_weights = cw
 
1630
  mrt.temperature = websocket._rt_temp
1631
  mrt.topk = websocket._rt_topk
1632
 
1633
+ # ramp style
1634
+ ramp = float(getattr(websocket, "_style_ramp_s", 0.0) or 0.0)
1635
+ if ramp <= 0.0:
1636
+ websocket._style_cur = websocket._style_tgt
1637
+ else:
1638
+ step = min(1.0, chunk_secs / ramp)
1639
+ websocket._style_cur = websocket._style_cur + step * (websocket._style_tgt - websocket._style_cur)
1640
+
1641
+ wav, new_state = mrt.generate_chunk(state=websocket._state, style=websocket._style_cur)
1642
  websocket._state = new_state
1643
 
1644
  x = wav.samples.astype(np.float32, copy=False)
 
1736
  text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1737
 
1738
  _ensure_assets_loaded()
1739
+ websocket._style_tgt = build_style_vector(
 
1740
  websocket._mrt,
1741
  text_styles=text_list,
1742
  text_weights=text_w,
 
1745
  mean_weight=float(websocket._rt_mean),
1746
  centroid_weights=websocket._rt_centroid_weights,
1747
  )
1748
+ # optionally allow live changes to ramp:
1749
+ if "style_ramp_seconds" in msg:
1750
+ try: websocket._style_ramp_s = float(msg["style_ramp_seconds"])
1751
+ except: pass
1752
  await send_json({"type":"status","updated":"rt-knobs+style"})
1753
 
1754
  elif mtype == "consume" and mode == "bar":
jam_worker.py CHANGED
@@ -35,6 +35,7 @@ class JamParams:
35
  guidance_weight: float = 1.1
36
  temperature: float = 1.1
37
  topk: int = 40
 
38
 
39
 
40
  @dataclass
@@ -90,7 +91,11 @@ class JamWorker(threading.Thread):
90
  self.mrt.topk = int(self.params.topk)
91
 
92
  # style vector (already normalized upstream)
93
- self._style_vec = self.params.style_vec
 
 
 
 
94
 
95
  # codec/setup
96
  self._codec_fps = float(self.mrt.codec.frame_rate)
@@ -124,7 +129,7 @@ class JamWorker(threading.Thread):
124
 
125
  # control flags
126
  self._stop_event = threading.Event()
127
- self._max_buffer_ahead = 5
128
 
129
  # reseed queues (install at next bar boundary after emission)
130
  self._pending_reseed: Optional[dict] = None # legacy full reset path (kept for fallback)
@@ -136,6 +141,17 @@ class JamWorker(threading.Thread):
136
 
137
  # ---------- lifecycle ----------
138
 
 
 
 
 
 
 
 
 
 
 
 
139
  def stop(self):
140
  self._stop_event.set()
141
 
@@ -400,17 +416,6 @@ class JamWorker(threading.Thread):
400
  # also remember this as new "original"
401
  self._original_context_tokens = np.copy(context_tokens)
402
 
403
- def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
404
- """Queue a splice reseed to be applied right after the next emitted loop.
405
- For now, we simply replace the context by recent wave tail; anchor is accepted
406
- for API compatibility and future crossfade/token-splice logic."""
407
- recent_wav = recent_wav.as_stereo().resample(self._model_sr)
408
- tail = take_bar_aligned_tail(recent_wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
409
- tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
410
- depth = int(self.mrt.config.decoder_codec_rvq_depth)
411
- new_ctx = tokens_full[:, :depth]
412
- self._pending_reseed = {"ctx": new_ctx}
413
-
414
  # ---------- core streaming helpers ----------
415
 
416
  def _append_model_chunk_and_spool(self, wav: au.Waveform):
@@ -538,8 +543,20 @@ class JamWorker(threading.Thread):
538
  # generate next model chunk
539
  # snapshot current style vector under lock for this step
540
  with self._lock:
541
- style_vec = self.params.style_vec
542
- wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
 
 
 
 
 
 
 
 
 
 
 
 
543
  # append and spool
544
  self._append_model_chunk_and_spool(wav)
545
  # try emitting zero or more chunks if available
 
35
  guidance_weight: float = 1.1
36
  temperature: float = 1.1
37
  topk: int = 40
38
+ style_ramp_seconds: float = 0.0 # 0 => instant (current behavior), try 6.0–10.0 for gentle glides
39
 
40
 
41
  @dataclass
 
91
  self.mrt.topk = int(self.params.topk)
92
 
93
  # style vector (already normalized upstream)
94
+ self._style_vec = (None if self.params.style_vec is None
95
+ else np.array(self.params.style_vec, dtype=np.float32, copy=True))
96
+ self._chunk_secs = (
97
+ self.mrt.config.chunk_length_frames * self.mrt.config.frame_length_samples
98
+ ) / float(self._model_sr) # ≈ 2.0 s by default
99
 
100
  # codec/setup
101
  self._codec_fps = float(self.mrt.codec.frame_rate)
 
129
 
130
  # control flags
131
  self._stop_event = threading.Event()
132
+ self._max_buffer_ahead = 1
133
 
134
  # reseed queues (install at next bar boundary after emission)
135
  self._pending_reseed: Optional[dict] = None # legacy full reset path (kept for fallback)
 
141
 
142
  # ---------- lifecycle ----------
143
 
144
+ def set_buffer_seconds(self, seconds: float):
145
+ """Clamp how far ahead we allow, in *seconds* of audio."""
146
+ chunk_secs = float(self.params.bars_per_chunk) * self._bar_clock.seconds_per_bar()
147
+ max_chunks = max(0, int(round(seconds / max(chunk_secs, 1e-6))))
148
+ with self._cv:
149
+ self._max_buffer_ahead = max_chunks
150
+
151
+ def set_buffer_chunks(self, k: int):
152
+ with self._cv:
153
+ self._max_buffer_ahead = max(0, int(k))
154
+
155
  def stop(self):
156
  self._stop_event.set()
157
 
 
416
  # also remember this as new "original"
417
  self._original_context_tokens = np.copy(context_tokens)
418
 
 
 
 
 
 
 
 
 
 
 
 
419
  # ---------- core streaming helpers ----------
420
 
421
  def _append_model_chunk_and_spool(self, wav: au.Waveform):
 
543
  # generate next model chunk
544
  # snapshot current style vector under lock for this step
545
  with self._lock:
546
+ target = self.params.style_vec
547
+ if target is None:
548
+ style_to_use = None
549
+ else:
550
+ if self._style_vec is None: # first use: start exactly at initial style (no glide)
551
+ self._style_vec = np.array(target, dtype=np.float32, copy=True)
552
+ else:
553
+ ramp = float(self.params.style_ramp_seconds or 0.0)
554
+ step = 1.0 if ramp <= 0.0 else min(1.0, self._chunk_secs / ramp)
555
+ # linear ramp in embedding space
556
+ self._style_vec += step * (target.astype(np.float32, copy=False) - self._style_vec)
557
+ style_to_use = self._style_vec
558
+
559
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_to_use)
560
  # append and spool
561
  self._append_model_chunk_and_spool(wav)
562
  # try emitting zero or more chunks if available