thecollabagepatch committed on
Commit
02fcba6
·
1 Parent(s): 7ae6392

let's see how far the vibes go

Browse files
Files changed (1) hide show
  1. jam_worker.py +79 -8
jam_worker.py CHANGED
@@ -167,22 +167,93 @@ class JamWorker(threading.Thread):
167
 
168
  # ---------- context / reseed ----------
169
 
170
def _install_context_from_loop(self, loop: au.Waveform):
    """Seed a fresh generation state from ``loop``.

    The loop is converted to stereo at ``self._model_sr``, cut with
    ``take_bar_aligned_tail`` (using the current bpm, beats-per-bar and
    ``self._ctx_seconds``), encoded with the codec and truncated to the
    decoder's RVQ depth.  The resulting tokens become the context of a newly
    initialised state, and a separate copy is stored in
    ``self._original_context_tokens``.
    """
    prepared = loop.as_stereo().resample(self._model_sr)
    bar_tail = take_bar_aligned_tail(
        prepared,
        self.params.bpm,
        self.params.beats_per_bar,
        self._ctx_seconds,
    )
    encoded = self.mrt.codec.encode(bar_tail).astype(np.int32)
    rvq_depth = int(self.mrt.config.decoder_codec_rvq_depth)
    context_tokens = encoded[:, :rvq_depth]

    # Swap in a brand-new state that carries the encoded context.
    fresh_state = self.mrt.init_state()
    fresh_state.context_tokens = context_tokens
    self.state = fresh_state

    # Independent copy, retained for later splices.
    self._original_context_tokens = np.copy(context_tokens)
185
 
 
 
 
 
 
 
186
  def reseed_from_waveform(self, wav: au.Waveform):
187
  """Immediate reseed: replace context from provided wave (bar-aligned tail)."""
188
  wav = wav.as_stereo().resample(self._model_sr)
 
167
 
168
  # ---------- context / reseed ----------
169
 
170
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
    """Encode ``loop`` into exactly ``self._ctx_frames`` frames of context tokens.

    The loop is converted to stereo at ``self._model_sr`` and tiled
    end-to-end until enough audio is available, so the selected audio always
    ends at the end of ``loop``.  A tail of whole bars (as many as fit in
    ``self._ctx_seconds``, but at least one) is taken from the end and, when
    it is shorter than the context, extended to the left with the audio that
    immediately precedes it.  The audio is then snapped to the exact sample
    count, encoded, cut to the decoder's RVQ depth, and finally snapped to the
    expected frame count.

    Args:
        loop: Source audio.  An empty loop is replaced by a single silent
            stereo sample before tiling.

    Returns:
        ``int32`` array whose first axis is frames and whose second axis
        holds the first ``decoder_codec_rvq_depth`` codebooks.  The frame
        count equals ``self._ctx_frames`` whenever the codec yields at least
        one frame.
    """
    wav = loop.as_stereo().resample(self._model_sr)
    data = wav.samples.astype(np.float32, copy=False)
    if data.ndim == 1:
        data = data[:, None]
    if data.shape[0] == 0:
        data = np.zeros((1, 2), dtype=np.float32)

    sr = int(self._model_sr)
    spb = self._bar_clock.seconds_per_bar()
    ctx_sec = float(self._ctx_seconds)
    ctx_samps = int(round(ctx_sec * sr))

    # Whole bars that fit inside the context window (never fewer than one).
    bars_fit = max(1, int(ctx_sec // spb))
    tail_len_samps = int(round(bars_fit * spb * sr))

    # Tile the loop so both the tail and the left-fill can be sliced from it.
    need = ctx_samps + tail_len_samps
    reps = int(np.ceil(need / float(data.shape[0])))
    tiled = np.tile(data, (reps, 1))

    end = tiled.shape[0]
    tail_start = end - tail_len_samps
    tail = tiled[tail_start:end]

    # Left-fill up to the context length; the end of the audio stays fixed.
    pad_len = ctx_samps - tail.shape[0]
    if pad_len > 0:
        pre = tiled[tail_start - pad_len:tail_start]
        ctx = np.concatenate([pre, tail], axis=0)
    else:
        ctx = tail[-ctx_samps:]

    # Final snap to the exact sample count: silence goes on the left,
    # trimming removes the oldest samples.
    if ctx.shape[0] < ctx_samps:
        silence = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
        ctx = np.concatenate([silence, ctx], axis=0)
    elif ctx.shape[0] > ctx_samps:
        ctx = ctx[-ctx_samps:]

    exact = au.Waveform(ctx, sr)
    tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
    depth = int(self.mrt.config.decoder_codec_rvq_depth)
    tokens = tokens_full[:, :depth]

    # Last defense: force the expected frame count without moving the end.
    frames = tokens.shape[0]
    expected_frames = int(self._ctx_frames)
    if frames < expected_frames:
        # The padding is prepended, so hold the *first* (oldest) frame.
        # Repeating the last frame here would put a copy of the newest frame
        # at the start of the context.
        lead = np.repeat(tokens[:1, :], expected_frames - frames, axis=0)
        tokens = np.concatenate([lead, tokens], axis=0)
    elif frames > expected_frames:
        tokens = tokens[-expected_frames:, :]
    return tokens
232
 
233
+
234
def _install_context_from_loop(self, loop: au.Waveform):
    """Install a fresh model state whose context is encoded from ``loop``.

    The tokens come from ``self._encode_exact_context_tokens``.  They are
    attached to a newly initialised state, which replaces ``self.state``,
    and a separate copy is stored in ``self._original_context_tokens``.
    """
    ctx_tokens = self._encode_exact_context_tokens(loop)

    fresh_state = self.mrt.init_state()
    fresh_state.context_tokens = ctx_tokens
    self.state = fresh_state

    # Independent copy of the installed context.
    self._original_context_tokens = np.copy(ctx_tokens)
241
 
242
def reseed_from_waveform(self, wav: au.Waveform):
    """Immediate reseed: rebuild the context from ``wav`` and swap in a new state.

    Tokens are produced by ``self._encode_exact_context_tokens``.  Besides
    replacing ``self.state`` and ``self._original_context_tokens``, this
    clears ``self._model_stream`` so the next chunk starts cleanly.
    """
    # NOTE(review): another ``def reseed_from_waveform`` appears further down
    # in this class. If both stay in the class body, the later one replaces
    # this definition -- confirm the older one is removed.
    ctx_tokens = self._encode_exact_context_tokens(wav)

    fresh_state = self.mrt.init_state()
    fresh_state.context_tokens = ctx_tokens
    self.state = fresh_state

    # Drop model-domain continuity.
    self._model_stream = None
    self._original_context_tokens = np.copy(ctx_tokens)
250
 
251
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
    """Queue a splice reseed to be applied right after the next emitted loop.

    The context encoded from ``recent_wav`` is stored under the ``"ctx"`` key
    of ``self._pending_reseed``; nothing is installed here.

    ``anchor_bars`` is accepted but not used by this method.
    """
    self._pending_reseed = dict(ctx=self._encode_exact_context_tokens(recent_wav))
255
+
256
+
257
  def reseed_from_waveform(self, wav: au.Waveform):
258
  """Immediate reseed: replace context from provided wave (bar-aligned tail)."""
259
  wav = wav.as_stereo().resample(self._model_sr)