thecollabagepatch committed
Commit 384e4ac · 1 Parent(s): 892466c

creating model_management.py for extraction

Files changed (3)
  1. Dockerfile +3 -0
  2. app.py +239 -739
  3. model_management.py +374 -0
Dockerfile CHANGED
@@ -143,6 +143,9 @@ COPY --chown=appuser:appuser utils.py /home/appuser/app/utils.py
 COPY --chown=appuser:appuser jam_worker.py /home/appuser/app/jam_worker.py
 
 COPY --chown=appuser:appuser one_shot_generation.py /home/appuser/app/one_shot_generation.py
+
+COPY --chown=appuser:appuser model_management.py /home/appuser/app/model_management.py
+
 COPY --chown=appuser:appuser documentation.html /home/appuser/app/documentation.html
 
 # Create docs directory and copy documentation files
app.py CHANGED
@@ -74,6 +74,8 @@ from huggingface_hub import snapshot_download, HfApi
74
 
75
  from pydantic import BaseModel
76
 
 
 
77
  # ---- Finetune assets (mean & centroids) --------------------------------------
78
  _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
79
  _ASSETS_REPO_ID: str | None = None
@@ -82,28 +84,38 @@ _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
82
 
83
  _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
84
 
85
- def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
86
- """
87
- List available checkpoint steps in a HF model repo without downloading all weights.
88
- Looks for:
89
- checkpoint_<step>/
90
- checkpoint_<step>.tgz | .tar.gz
91
- archives/checkpoint_<step>.tgz | .tar.gz
92
- """
93
- api = HfApi()
94
- files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
95
- steps = set()
96
- for f in files:
97
- m = _STEP_RE.search(f)
98
- if m:
99
- try:
100
- steps.add(int(m.group(1)))
101
- except:
102
- pass
103
- return sorted(steps)
104
 
105
- def _step_exists(repo_id: str, revision: str, step: int) -> bool:
106
- return step in _list_ckpt_steps(repo_id, revision)
107
 
108
  def _any_jam_running() -> bool:
109
  with jam_lock:
@@ -117,131 +129,131 @@ def _stop_all_jams(timeout: float = 5.0):
117
  w.join(timeout=timeout)
118
  jam_registry.pop(sid, None)
119
 
120
- def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
121
- """
122
- Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
123
- Safe to call multiple times; will overwrite globals if successful.
124
- """
125
- global _ASSETS_REPO_ID, _MEAN_EMBED, _CENTROIDS
126
- repo_id = repo_id or _FINETUNE_REPO_DEFAULT
127
- try:
128
- from huggingface_hub import hf_hub_download
129
- mean_path = None
130
- cent_path = None
131
- try:
132
- mean_path = hf_hub_download(repo_id, filename="mean_style_embed.npy", repo_type="model")
133
- except Exception:
134
- pass
135
- try:
136
- cent_path = hf_hub_download(repo_id, filename="cluster_centroids.npy", repo_type="model")
137
- except Exception:
138
- pass
139
-
140
- if mean_path is None and cent_path is None:
141
- return False, f"No finetune asset files found in repo {repo_id}"
142
-
143
- if mean_path is not None:
144
- m = np.load(mean_path)
145
- if m.ndim != 1:
146
- return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
147
- else:
148
- m = None
149
-
150
- if cent_path is not None:
151
- c = np.load(cent_path)
152
- if c.ndim != 2:
153
- return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"
154
- else:
155
- c = None
156
-
157
- # Optional: shape check vs model embedding dim once model is alive
158
- try:
159
- d = int(get_mrt().style_model.config.embedding_dim)
160
- if m is not None and m.shape[0] != d:
161
- return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
162
- if c is not None and c.shape[1] != d:
163
- return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"
164
- except Exception:
165
- # Model not built yet; we’ll trust the files and rely on runtime checks later
166
- pass
167
-
168
- _MEAN_EMBED = m.astype(np.float32, copy=False) if m is not None else None
169
- _CENTROIDS = c.astype(np.float32, copy=False) if c is not None else None
170
- _ASSETS_REPO_ID = repo_id
171
- logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
172
- repo_id,
173
- "yes" if _MEAN_EMBED is not None else "no",
174
- f"{_CENTROIDS.shape[0]}x{_CENTROIDS.shape[1]}" if _CENTROIDS is not None else "no")
175
- return True, "ok"
176
- except Exception as e:
177
- logging.exception("Failed to load finetune assets: %s", e)
178
- return False, str(e)
179
-
180
- def _ensure_assets_loaded():
181
- # Best-effort lazy load if nothing is loaded yet
182
- if _MEAN_EMBED is None and _CENTROIDS is None:
183
- _load_finetune_assets_from_hf(_ASSETS_REPO_ID or _FINETUNE_REPO_DEFAULT)
184
  # ------------------------------------------------------------------------------
185
 
186
- def _resolve_checkpoint_dir() -> str | None:
187
- repo_id = os.getenv("MRT_CKPT_REPO")
188
- if not repo_id:
189
- return None
190
- step = os.getenv("MRT_CKPT_STEP") # e.g. "1863001"
191
-
192
- root = Path(snapshot_download(
193
- repo_id=repo_id,
194
- repo_type="model",
195
- revision=os.getenv("MRT_CKPT_REV", "main"),
196
- local_dir="/home/appuser/.cache/mrt_ckpt/repo",
197
- local_dir_use_symlinks=False,
198
- ))
199
-
200
- # Prefer an archive if present (more reliable for Zarr/T5X)
201
- arch_names = [
202
- f"checkpoint_{step}.tgz",
203
- f"checkpoint_{step}.tar.gz",
204
- f"archives/checkpoint_{step}.tgz",
205
- f"archives/checkpoint_{step}.tar.gz",
206
- ] if step else []
207
-
208
- cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
209
- cache_root.mkdir(parents=True, exist_ok=True)
210
- for name in arch_names:
211
- arch = root / name
212
- if arch.is_file():
213
- out_dir = cache_root / f"checkpoint_{step}"
214
- marker = out_dir.with_suffix(".ok")
215
- if not marker.exists():
216
- out_dir.mkdir(parents=True, exist_ok=True)
217
- with tarfile.open(arch, "r:*") as tf:
218
- tf.extractall(out_dir)
219
- marker.write_text("ok")
220
- # sanity: require .zarray to exist inside the extracted tree
221
- if not any(out_dir.rglob(".zarray")):
222
- raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
223
- return str(out_dir / f"checkpoint_{step}") if (out_dir / f"checkpoint_{step}").exists() else str(out_dir)
224
-
225
- # No archive; try raw folder from repo and sanity check.
226
- if step:
227
- raw = root / f"checkpoint_{step}"
228
- if raw.is_dir():
229
- if not any(raw.rglob(".zarray")):
230
- raise RuntimeError(
231
- f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
232
- "Upload as a .tgz or push via git from a Unix shell."
233
- )
234
- return str(raw)
235
-
236
- # Pick latest if no step
237
- step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
238
- if step_dirs:
239
- pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
240
- if not any(pick.rglob(".zarray")):
241
- raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
242
- return str(pick)
243
-
244
- return None
245
 
246
 
247
  async def send_json_safe(ws: WebSocket, obj) -> bool:
@@ -292,252 +304,6 @@ def _patch_t5x_for_gpu_coords():
292
  # Call the patch immediately at import time (before MagentaRT init)
293
  _patch_t5x_for_gpu_coords()
294
 
295
- def create_documentation_interface():
296
- """Create a Gradio interface for documentation and transparency"""
297
- with gr.Blocks(title="MagentaRT Research API", theme=gr.themes.Soft()) as interface:
298
- gr.Markdown(
299
- r"""
300
- # 🎵 MagentaRT Live Music Generation Research API
301
-
302
- **Research-only implementation for iOS/web app development**
303
-
304
- This API uses Google's [MagentaRT](https://github.com/magenta/magenta-realtime) to generate
305
- continuous music either as **bar-aligned chunks over HTTP** or as **low-latency realtime chunks via WebSocket**.
306
- """
307
- )
308
-
309
- with gr.Tabs():
310
- # ------------------------------------------------------------------
311
- # About & current status
312
- # ------------------------------------------------------------------
313
- with gr.Tab("📖 About & Status"):
314
- gr.Markdown(
315
- r"""
316
- ## What this is
317
- We're exploring AI‑assisted loop‑based music creation that can run on GPUs (not just TPUs) and stream to apps in realtime.
318
-
319
- ### Implemented backends
320
- - **HTTP (bar‑aligned):** `/generate`, `/jam/start`, `/jam/next`, `/jam/stop`, `/jam/update`, etc.
321
- - **WebSocket (realtime):** `ws://…/ws/jam` with `mode="rt"` (Colab‑style continuous chunks). New in this build.
322
-
323
- ## What we learned (GPU notes)
324
- - **L40S 48GB:** comfortably **faster than realtime** → we added a `pace: "realtime"` switch so the server doesn’t outrun playback.
325
- - **L4 24GB:** **consistently just under realtime**; even with pre‑roll buffering, TF32/JAX tunings, reduced chunk size, and the **base** checkpoint, we still see eventual under‑runs.
326
- - **Implication:** For production‑quality realtime, aim for ~**40GB VRAM** per user/session (e.g., **A100 40GB**, or MIG slices ≈ **35–40GB** on newer parts). Smaller GPUs can demo, but sustained realtime is not reliable.
327
-
328
- ## Model / audio specs
329
- - **Model:** MagentaRT (T5X; decoder RVQ depth = 16)
330
- - **Audio:** 48 kHz stereo, 2.0 s chunks by default, 40 ms crossfade
331
- - **Context:** 10 s rolling context window
332
- """
333
- )
334
-
335
- # ------------------------------------------------------------------
336
- # HTTP API
337
- # ------------------------------------------------------------------
338
- with gr.Tab("🔧 API (HTTP)"):
339
- gr.Markdown(
340
- r"""
341
- ### Single Generation
342
- ```bash
343
- curl -X POST \
344
- "$HOST/generate" \
345
- -F "loop_audio=@drum_loop.wav" \
346
- -F "bpm=120" \
347
- -F "bars=8" \
348
- -F "styles=acid house,techno" \
349
- -F "guidance_weight=5.0" \
350
- -F "temperature=1.1"
351
- ```
352
-
353
- ### Continuous Jamming (bar‑aligned, HTTP)
354
- ```bash
355
- # 1) Start a session
356
- echo $(curl -s -X POST "$HOST/jam/start" \
357
- -F "loop_audio=@loop.wav" \
358
- -F "bpm=120" \
359
- -F "bars_per_chunk=8") | jq .
360
- # → {"session_id":"…"}
361
-
362
- # 2) Pull next chunk (repeat)
363
- curl "$HOST/jam/next?session_id=$SESSION"
364
-
365
- # 3) Stop
366
- curl -X POST "$HOST/jam/stop" \
367
- -H "Content-Type: application/json" \
368
- -d '{"session_id":"'$SESSION'"}'
369
- ```
370
-
371
- ### Common parameters
372
- - **bpm** *(int)* – beats per minute
373
- - **bars / bars_per_chunk** *(int)* – musical length
374
- - **styles** *(str)* – comma‑separated text prompts (mixed internally)
375
- - **guidance_weight** *(float)* – style adherence (CFG weight)
376
- - **temperature / topk** – sampling controls
377
- - **intro_bars_to_drop** *(int, /generate)* – generate-and-trim intro
378
- """
379
- )
380
-
381
- # ------------------------------------------------------------------
382
- # WebSocket API: realtime (‘rt’ mode)
383
- # ------------------------------------------------------------------
384
- with gr.Tab("🧩 API (WebSocket • rt mode)"):
385
- gr.Markdown(
386
- r"""
387
- Connect to `wss://…/ws/jam` and send a **JSON control stream**. In `rt` mode the server emits ~2 s WAV chunks (or binary frames) continuously.
388
-
389
- ### Start (client → server)
390
- ```jsonc
391
- {
392
- "type": "start",
393
- "mode": "rt",
394
- "binary_audio": false, // true → raw WAV bytes + separate chunk_meta
395
- "params": {
396
- "styles": "heavy metal", // or "jazz, hiphop"
397
- "style_weights": "1.0,1.0", // optional, auto‑normalized
398
- "temperature": 1.1,
399
- "topk": 40,
400
- "guidance_weight": 1.1,
401
- "pace": "realtime", // "realtime" | "asap" (default)
402
- "max_decode_frames": 50 // 50≈2.0s; try 36–45 on smaller GPUs
403
- }
404
- }
405
- ```
406
-
407
- ### Server events (server → client)
408
- - `{"type":"started","mode":"rt"}` – handshake
409
- - `{"type":"chunk","audio_base64":"…","metadata":{…}}` – base64 WAV
410
- - `metadata.sample_rate` *(int)* – usually 48000
411
- - `metadata.chunk_frames` *(int)* – e.g., 50
412
- - `metadata.chunk_seconds` *(float)* – frames / 25.0
413
- - `metadata.crossfade_seconds` *(float)* – typically 0.04
414
- - `{"type":"chunk_meta","metadata":{…}}` – sent **after** a binary frame when `binary_audio=true`
415
- - `{"type":"status",…}`, `{"type":"error",…}`, `{"type":"stopped"}`
416
-
417
- ### Update (client → server)
418
- ```jsonc
419
- {
420
- "type": "update",
421
- "styles": "jazz, hiphop",
422
- "style_weights": "1.0,0.8",
423
- "temperature": 1.2,
424
- "topk": 64,
425
- "guidance_weight": 1.0,
426
- "pace": "realtime", // optional live flip
427
- "max_decode_frames": 40 // optional; <= 50
428
- }
429
- ```
430
-
431
- ### Stop / ping
432
- ```json
433
- {"type":"stop"}
434
- {"type":"ping"}
435
- ```
436
-
437
- ### Browser quick‑start (schedules seamlessly with 25–40 ms crossfade)
438
- ```html
439
- <script>
440
- const XFADE = 0.025; // 25 ms
441
- let ctx, gain, ws, nextTime = 0;
442
- async function start(){
443
- ctx = new (window.AudioContext||window.webkitAudioContext)();
444
- gain = ctx.createGain(); gain.connect(ctx.destination);
445
- ws = new WebSocket("wss://YOUR_SPACE/ws/jam");
446
- ws.onopen = ()=> ws.send(JSON.stringify({
447
- type:"start", mode:"rt", binary_audio:false,
448
- params:{ styles:"warmup", temperature:1.1, topk:40, guidance_weight:1.1, pace:"realtime" }
449
- }));
450
- ws.onmessage = async ev => {
451
- const msg = JSON.parse(ev.data);
452
- if (msg.type === "chunk" && msg.audio_base64){
453
- const bin = atob(msg.audio_base64); const buf = new Uint8Array(bin.length);
454
- for (let i=0;i<bin.length;i++) buf[i] = bin.charCodeAt(i);
455
- const ab = buf.buffer; const audio = await ctx.decodeAudioData(ab);
456
- const src = ctx.createBufferSource(); const g = ctx.createGain();
457
- src.buffer = audio; src.connect(g); g.connect(gain);
458
- if (nextTime < ctx.currentTime + 0.05) nextTime = ctx.currentTime + 0.12;
459
- const startAt = nextTime, dur = audio.duration;
460
- nextTime = startAt + Math.max(0, dur - XFADE);
461
- g.gain.setValueAtTime(0, startAt);
462
- g.gain.linearRampToValueAtTime(1, startAt + XFADE);
463
- g.gain.setValueAtTime(1, startAt + Math.max(0, dur - XFADE));
464
- g.gain.linearRampToValueAtTime(0, startAt + dur);
465
- src.start(startAt);
466
- }
467
- };
468
- }
469
- </script>
470
- ```
471
-
472
- ### Python client (async)
473
- ```python
474
- import asyncio, json, websockets, base64, soundfile as sf, io
475
- async def run(url):
476
- async with websockets.connect(url) as ws:
477
- await ws.send(json.dumps({"type":"start","mode":"rt","binary_audio":False,
478
- "params": {"styles":"warmup","temperature":1.1,"topk":40,"guidance_weight":1.1,"pace":"realtime"}}))
479
- while True:
480
- msg = json.loads(await ws.recv())
481
- if msg.get("type") == "chunk":
482
- wav = base64.b64decode(msg["audio_base64"]) # bytes of a WAV
483
- x, sr = sf.read(io.BytesIO(wav), dtype="float32")
484
- print("chunk", x.shape, sr)
485
- elif msg.get("type") in ("stopped","error"): break
486
- asyncio.run(run("wss://YOUR_SPACE/ws/jam"))
487
- ```
488
- """
489
- )
490
-
491
- # ------------------------------------------------------------------
492
- # Performance & hardware guidance
493
- # ------------------------------------------------------------------
494
- with gr.Tab("📊 Performance & Hardware"):
495
- gr.Markdown(
496
- r"""
497
- ### Current observations
498
- - **L40S 48GB** → faster than realtime. Use `pace:"realtime"` to avoid client over‑buffering.
499
- - **L4 24GB** → slightly **below** realtime even with pre‑roll buffering, TF32/Autotune, smaller chunks (`max_decode_frames`), and the **base** checkpoint.
500
-
501
- ### Practical guidance
502
- - For consistent realtime, target **~40GB VRAM per active stream** (e.g., **A100 40GB**, or MIG slices ≈ **35–40GB** on newer GPUs).
503
- - Keep client‑side **overlap‑add** (25–40 ms) for seamless chunk joins.
504
- - Prefer **`pace:"realtime"`** once playback begins; use **ASAP** only to build a short pre‑roll if needed.
505
- - Optional knob: **`max_decode_frames`** (default **50** ≈ 2.0 s). Reducing to **36–45** can lower per‑chunk latency/VRAM, but doesn’t increase frames/sec throughput.
506
-
507
- ### Concurrency
508
- This research build is designed for **one active jam per GPU**. Concurrency would require GPU partitioning (MIG) or horizontal scaling with a session scheduler.
509
- """
510
- )
511
-
512
- # ------------------------------------------------------------------
513
- # Changelog & legal
514
- # ------------------------------------------------------------------
515
- with gr.Tab("🗒️ Changelog & Legal"):
516
- gr.Markdown(
517
- r"""
518
- ### Recent changes
519
- - New **WebSocket realtime** route: `/ws/jam` (`mode:"rt"`)
520
- - Added server pacing flag: `pace: "realtime" | "asap"`
521
- - Exposed `max_decode_frames` for shorter chunks on smaller GPUs
522
- - Client test page now does proper **overlap‑add** crossfade between chunks
523
-
524
- ### Licensing
525
- This project uses MagentaRT under:
526
- - **Code:** Apache 2.0
527
- - **Model weights:** CC‑BY 4.0
528
- Please review the MagentaRT repo for full terms.
529
- """
530
- )
531
-
532
- gr.Markdown(
533
- r"""
534
- ---
535
- **🔬 Research Project** | **📱 iOS/Web Development** | **🎵 Powered by MagentaRT**
536
- """
537
- )
538
-
539
- return interface
540
-
541
  jam_registry: dict[str, JamWorker] = {}
542
  jam_lock = threading.Lock()
543
 
@@ -562,170 +328,6 @@ try:
562
  except Exception:
563
  _HAS_LOUDNORM = False
564
 
565
- # # ----------------------------
566
- # # Main generation (single combined style vector)
567
- # # ----------------------------
568
- # def generate_loop_continuation_with_mrt(
569
- # mrt,
570
- # input_wav_path: str,
571
- # bpm: float,
572
- # extra_styles=None,
573
- # style_weights=None,
574
- # bars: int = 8,
575
- # beats_per_bar: int = 4,
576
- # loop_weight: float = 1.0,
577
- # loudness_mode: str = "auto",
578
- # loudness_headroom_db: float = 1.0,
579
- # intro_bars_to_drop: int = 0, # <— NEW
580
- # ):
581
- # # Load & prep (unchanged)
582
- # loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
583
-
584
- # # Use tail for context (your recent change)
585
- # codec_fps = float(mrt.codec.frame_rate)
586
- # ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
587
- # loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
588
-
589
- # tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
590
- # tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
591
-
592
- # # Bar-aligned token window (unchanged)
593
- # context_tokens = make_bar_aligned_context(
594
- # tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
595
- # ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
596
- # )
597
- # state = mrt.init_state()
598
- # state.context_tokens = context_tokens
599
-
600
- # # STYLE embed (optional: switch to loop_for_context if you want stronger “recent” bias)
601
- # loop_embed = mrt.embed_style(loop_for_context)
602
- # embeds, weights = [loop_embed], [float(loop_weight)]
603
- # if extra_styles:
604
- # for i, s in enumerate(extra_styles):
605
- # if s.strip():
606
- # embeds.append(mrt.embed_style(s.strip()))
607
- # w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
608
- # weights.append(float(w))
609
- # wsum = float(sum(weights)) or 1.0
610
- # weights = [w / wsum for w in weights]
611
- # combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
612
-
613
- # # --- Length math ---
614
- # seconds_per_bar = beats_per_bar * (60.0 / bpm)
615
- # total_secs = bars * seconds_per_bar
616
- # drop_bars = max(0, int(intro_bars_to_drop))
617
- # drop_secs = min(drop_bars, bars) * seconds_per_bar # clamp to <= bars
618
- # gen_total_secs = total_secs + drop_secs # generate extra
619
-
620
- # # Chunk scheduling to cover gen_total_secs
621
- # chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
622
- # steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1 # pad then trim
623
-
624
- # # Generate
625
- # chunks = []
626
- # for _ in range(steps):
627
- # wav, state = mrt.generate_chunk(state=state, style=combined_style)
628
- # chunks.append(wav)
629
-
630
- # # Stitch continuous audio
631
- # stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
632
-
633
- # # Trim to generated length (bars + dropped bars)
634
- # stitched = hard_trim_seconds(stitched, gen_total_secs)
635
-
636
- # # 👉 Drop the intro bars
637
- # if drop_secs > 0:
638
- # n_drop = int(round(drop_secs * stitched.sample_rate))
639
- # stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
640
-
641
- # # Final exact-length trim to requested bars
642
- # out = hard_trim_seconds(stitched, total_secs)
643
-
644
- # # Final polish AFTER drop
645
- # out = out.peak_normalize(0.95)
646
- # apply_micro_fades(out, 5)
647
-
648
- # # Loudness match to input (after drop) so bar 1 sits right
649
- # out, loud_stats = match_loudness_to_reference(
650
- # ref=loop, target=out,
651
- # method=loudness_mode, headroom_db=loudness_headroom_db
652
- # )
653
-
654
- # return out, loud_stats
655
-
656
- # # untested.
657
- # # not sure how it will retain the input bpm. we may want to use a metronome instead of silence. i think google might do that.
658
- # # does a generation with silent context rather than a combined loop
659
- # def generate_style_only_with_mrt(
660
- # mrt,
661
- # bpm: float,
662
- # bars: int = 8,
663
- # beats_per_bar: int = 4,
664
- # styles: str = "warmup",
665
- # style_weights: str = "",
666
- # intro_bars_to_drop: int = 0,
667
- # ):
668
- # """
669
- # Style-only, bar-aligned generation using a silent context (no input audio).
670
- # Returns: (au.Waveform out, dict loud_stats_or_None)
671
- # """
672
- # # ---- Build a 10s silent context, tokenized for the model ----
673
- # codec_fps = float(mrt.codec.frame_rate)
674
- # ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
675
- # sr = int(mrt.sample_rate)
676
-
677
- # silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
678
- # tokens_full = mrt.codec.encode(silent).astype(np.int32)
679
- # tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
680
-
681
- # state = mrt.init_state()
682
- # state.context_tokens = tokens
683
-
684
- # # ---- Style vector (text prompts only, normalized weights) ----
685
- # prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
686
- # if not prompts:
687
- # prompts = ["warmup"]
688
- # sw = [float(x) for x in style_weights.split(",")] if style_weights else []
689
- # embeds, weights = [], []
690
- # for i, p in enumerate(prompts):
691
- # embeds.append(mrt.embed_style(p))
692
- # weights.append(sw[i] if i < len(sw) else 1.0)
693
- # wsum = float(sum(weights)) or 1.0
694
- # weights = [w / wsum for w in weights]
695
- # style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
696
-
697
- # # ---- Target length math ----
698
- # seconds_per_bar = beats_per_bar * (60.0 / bpm)
699
- # total_secs = bars * seconds_per_bar
700
- # drop_bars = max(0, int(intro_bars_to_drop))
701
- # drop_secs = min(drop_bars, bars) * seconds_per_bar
702
- # gen_total_secs = total_secs + drop_secs
703
-
704
- # # ~2.0s chunk length from model config
705
- # chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
706
-
707
- # # Generate enough chunks to cover total, plus a pad chunk for crossfade headroom
708
- # steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
709
-
710
- # chunks = []
711
- # for _ in range(steps):
712
- # wav, state = mrt.generate_chunk(state=state, style=style_vec)
713
- # chunks.append(wav)
714
-
715
- # # Stitch & trim to exact musical length
716
- # stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
717
- # stitched = hard_trim_seconds(stitched, gen_total_secs)
718
-
719
- # if drop_secs > 0:
720
- # n_drop = int(round(drop_secs * stitched.sample_rate))
721
- # stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
722
-
723
- # out = hard_trim_seconds(stitched, total_secs)
724
- # out = out.peak_normalize(0.95)
725
- # apply_micro_fades(out, 5)
726
-
727
- # return out, None # loudness stats not applicable (no reference)
728
-
729
  def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
730
  extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
731
  if not extra:
@@ -836,12 +438,13 @@ def get_mrt():
836
  if _MRT is None:
837
  with _MRT_LOCK:
838
  if _MRT is None:
839
- ckpt_dir = _resolve_checkpoint_dir() # ← points to checkpoint_1863001
 
840
  _MRT = system.MagentaRT(
841
- tag=os.getenv("MRT_SIZE", "large"), # keep 'large' if finetuned from large
842
  guidance_weight=5.0,
843
  device="gpu",
844
- checkpoint_dir=ckpt_dir, # ← uses your finetune
845
  lazy=False,
846
  )
847
  return _MRT
@@ -948,7 +551,12 @@ def model_swap(step: int = Form(...)):
948
 
949
  @app.post("/model/assets/load")
950
  def model_assets_load(repo_id: str = Form(None)):
951
- ok, msg = _load_finetune_assets_from_hf(repo_id)
952
  return {"ok": ok, "message": msg, "repo_id": _ASSETS_REPO_ID,
953
  "mean": _MEAN_EMBED is not None,
954
  "centroids": None if _CENTROIDS is None else int(_CENTROIDS.shape[0])}
@@ -987,15 +595,14 @@ def model_config():
987
  step = os.getenv("MRT_CKPT_STEP")
988
  assets = os.getenv("MRT_ASSETS_REPO")
989
 
990
- # Best-effort local cache probe (no network)
991
- def _local_ckpt_dir(step_str: str | None) -> str | None:
992
- if not step_str:
993
- return None
994
  try:
995
  from pathlib import Path
996
  import re
997
- step = re.escape(str(step_str))
998
- candidates: list[str] = []
999
  for root in ("/home/appuser/.cache/mrt_ckpt/extracted",
1000
  "/home/appuser/.cache/mrt_ckpt/repo"):
1001
  p = Path(root)
@@ -1005,11 +612,9 @@ def model_config():
1005
  for d in p.rglob(f"checkpoint_{step}"):
1006
  if d.is_dir():
1007
  candidates.append(str(d))
1008
- return candidates[0] if candidates else None
1009
  except Exception:
1010
- return None
1011
-
1012
- local_ckpt = _local_ckpt_dir(step)
1013
 
1014
  return {
1015
  "size": size,
@@ -1032,160 +637,89 @@ def model_config():
1032
 
1033
  @app.get("/model/checkpoints")
1034
  def model_checkpoints(repo_id: str, revision: str = "main"):
1035
- steps = _list_ckpt_steps(repo_id, revision)
1036
  return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
1037
 
1038
- class ModelSelect(BaseModel):
1039
- size: Optional[Literal["base","large"]] = None
1040
- repo_id: Optional[str] = None
1041
- revision: Optional[str] = "main"
1042
- step: Optional[Union[int, str]] = None # allow "latest"
1043
- assets_repo_id: Optional[str] = None # default: follow repo_id
1044
- sync_assets: bool = True # load mean/centroids from repo
1045
- prewarm: bool = False # call get_mrt() to build right away
1046
- stop_active: bool = True # auto-stop jams; else 409
1047
- dry_run: bool = False # validate only, don't swap
1048
 
1049
  @app.post("/model/select")
1050
  def model_select(req: ModelSelect):
1051
- # --- Current env defaults ---
1052
- global _MRT
1053
- cur = {
1054
- "size": os.getenv("MRT_SIZE", "large"),
1055
- "repo": os.getenv("MRT_CKPT_REPO"),
1056
- "rev": os.getenv("MRT_CKPT_REV", "main"),
1057
- "step": os.getenv("MRT_CKPT_STEP"),
1058
- "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
1059
- }
1060
-
1061
- # --- Flags for special step values ---
1062
- no_ckpt = isinstance(req.step, str) and req.step.lower() == "none"
1063
- latest = isinstance(req.step, str) and req.step.lower() == "latest"
1064
-
1065
- # --- Target selection (do not require repo when no_ckpt) ---
1066
- tgt = {
1067
- "size": (req.size or cur["size"]),
1068
- "repo": (None if no_ckpt else (req.repo_id or cur["repo"])),
1069
- "rev": (req.revision if req.revision is not None else cur["rev"]),
1070
- # None => resolve to "latest" below. Keep None for no_ckpt as well.
1071
- "step": (None if (no_ckpt or latest) else (str(req.step) if req.step is not None else cur["step"])),
1072
- "assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
1073
- }
1074
-
1075
- # ---------- CASE 1: run with NO FINETUNE (stock base/large) ----------
1076
- if no_ckpt:
1077
- preview = {
1078
- "target_size": tgt["size"],
1079
- "target_repo": None,
1080
- "target_revision": None,
1081
- "target_step": None,
1082
- "assets_repo": None,
1083
- "assets_probe": {"ok": True, "message": "skipped"},
1084
- "active_jam": _any_jam_running(),
1085
- }
1086
- if req.dry_run:
1087
- return {"ok": True, "dry_run": True, **preview}
1088
-
1089
- # Jam policy
1090
- if _any_jam_running():
1091
- if req.stop_active:
1092
- _stop_all_jams()
1093
- else:
1094
- raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
1095
-
1096
- # Clear checkpoint + asset env so get_mrt() uses stock weights
1097
- for k in ("MRT_CKPT_REPO", "MRT_CKPT_REV", "MRT_CKPT_STEP", "MRT_ASSETS_REPO"):
1098
- os.environ.pop(k, None)
1099
- os.environ["MRT_SIZE"] = str(tgt["size"])
1100
-
1101
- # Rebuild model and optionally prewarm
1102
-
1103
- with _MRT_LOCK:
1104
- _MRT = None
1105
- if req.prewarm:
1106
- get_mrt()
1107
-
1108
- return {"ok": True, **preview}
1109
-
1110
- # ---------- CASE 2: select a repo + step (supports "latest") ----------
1111
- if not tgt["repo"]:
1112
- raise HTTPException(status_code=400, detail="repo_id is required for model selection.")
1113
-
1114
- # 1) enumerate available steps
1115
- steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
1116
- if not steps:
1117
- return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1118
-
1119
- # 2) choose step (explicit or latest)
1120
- chosen_step = int(tgt["step"]) if tgt["step"] is not None else steps[-1]
1121
- if chosen_step not in steps:
1122
- return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1123
-
1124
- # 3) optional finetune assets probe (no downloads, just listing)
1125
- assets_ok, assets_msg = True, "skipped"
1126
- if req.sync_assets:
1127
- try:
1128
- api = HfApi()
1129
- files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
1130
- if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
1131
- assets_ok, assets_msg = False, f"No finetune asset files in {tgt['assets']}"
1132
- else:
1133
- assets_msg = "found"
1134
- except Exception as e:
1135
- assets_ok, assets_msg = False, f"probe failed: {e}"
1136
-
1137
- preview = {
1138
- "target_size": tgt["size"],
1139
- "target_repo": tgt["repo"],
1140
- "target_revision": tgt["rev"],
1141
- "target_step": chosen_step,
1142
- "assets_repo": (tgt["assets"] if req.sync_assets else None),
1143
- "assets_probe": {"ok": assets_ok, "message": assets_msg},
1144
- "active_jam": _any_jam_running(),
1145
- }
1146
-
1147
  if req.dry_run:
1148
- return {"ok": True, "dry_run": True, **preview}
1149
 
1150
- # Jam policy
1151
  if _any_jam_running():
1152
  if req.stop_active:
1153
  _stop_all_jams()
1154
  else:
1155
  raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
1156
 
1157
- # 4) atomic swap with rollback
1158
  old_env = {
1159
- "MRT_SIZE": os.getenv("MRT_SIZE"),
1160
- "MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
1161
- "MRT_CKPT_REV": os.getenv("MRT_CKPT_REV"),
1162
- "MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
1163
- "MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
1164
  }
 
1165
  try:
1166
- os.environ["MRT_SIZE"] = str(tgt["size"])
1167
- os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
1168
- os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
1169
- os.environ["MRT_CKPT_STEP"] = str(chosen_step)
1170
- if req.sync_assets:
1171
- os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])
1172
-
1173
- # force rebuild
1174
-
1175
  with _MRT_LOCK:
1176
  _MRT = None
1177
 
1178
- # optionally load finetune assets now
1179
- if req.sync_assets:
1180
- _load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
1181
-
1182
- # optional prewarm to amortize JIT
1183
  if req.prewarm:
1184
  get_mrt()
1185
 
1186
- return {"ok": True, **preview}
 
1187
  except Exception as e:
1188
- # rollback on error
1189
  for k, v in old_env.items():
1190
  if v is None:
1191
  os.environ.pop(k, None)
@@ -1193,6 +727,7 @@ def model_select(req: ModelSelect):
1193
  os.environ[k] = v
1194
  with _MRT_LOCK:
1195
  _MRT = None
 
1196
  try:
1197
  get_mrt()
1198
  except Exception:
@@ -1379,7 +914,7 @@ def jam_start(
1379
  topk: int = Form(40),
1380
  target_sample_rate: int | None = Form(None),
1381
  ):
1382
- _ensure_assets_loaded()
1383
 
1384
  # enforce single active jam per GPU
1385
  with jam_lock:
@@ -1534,7 +1069,7 @@ def jam_update(
1534
  mean: Optional[float] = Form(None),
1535
  centroid_weights: str = Form(""),
1536
  ):
1537
- _ensure_assets_loaded()
1538
 
1539
  with jam_lock:
1540
  worker = jam_registry.get(session_id)
@@ -1842,7 +1377,7 @@ async def ws_jam(websocket: WebSocket):
1842
  state.context_tokens = tokens
1843
 
1844
  # Parse params (including steering)
1845
- _ensure_assets_loaded()
1846
  styles_str = params.get("styles", "warmup") or ""
1847
  style_weights_str = params.get("style_weights", "") or ""
1848
  mean_w = float(params.get("mean", 0.0) or 0.0)
@@ -2009,7 +1544,7 @@ async def ws_jam(websocket: WebSocket):
2009
  text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
2010
  text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
2011
 
2012
- _ensure_assets_loaded()
2013
  websocket._style_tgt = build_style_vector(
2014
  websocket._mrt,
2015
  text_styles=text_list,
@@ -2116,39 +1651,4 @@ def read_root():
2116
  <p>Documentation file not found. Please check documentation.html</p>
2117
  </body></html>
2118
  """
2119
- return Response(content=html_content, media_type="text/html")
2120
-
2121
- def load_doc_content(filename: str) -> str:
2122
- """Load markdown content from docs directory, with fallback."""
2123
- try:
2124
- doc_path = Path(__file__).parent / "docs" / filename
2125
- return doc_path.read_text(encoding='utf-8')
2126
- except FileNotFoundError:
2127
- return f"⚠️ Documentation file `{filename}` not found. Please check the docs directory."
2128
- except Exception as e:
2129
- return f"⚠️ Error loading `{filename}`: {e}"
2130
-
2131
- @app.get("/documentation")
2132
- def documentation():
2133
- # Just return a simple combined markdown page
2134
- all_content = f"""
2135
- # MagentaRT Documentation
2136
-
2137
- ## About & Status
2138
- {load_doc_content("about_status.md")}
2139
-
2140
- ## HTTP API
2141
- {load_doc_content("api_http.md")}
2142
-
2143
- ## WebSocket API
2144
- {load_doc_content("api_websocket.md")}
2145
-
2146
- ## Performance
2147
- {load_doc_content("performance.md")}
2148
-
2149
- ## Changelog
2150
- {load_doc_content("changelog.md")}
2151
- """
2152
-
2153
- # Convert markdown to HTML if you want, or just serve as plain text
2154
- return Response(content=all_content, media_type="text/plain")
 
74
 
75
  from pydantic import BaseModel
76
 
77
+ from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
78
+
79
  # ---- Finetune assets (mean & centroids) --------------------------------------
80
  _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
81
  _ASSETS_REPO_ID: str | None = None
 
84
 
85
  _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
86
 
87
+ # Create instances (these don't modify globals)
88
+ asset_manager = AssetManager()
89
+ model_selector = ModelSelector(CheckpointManager(), asset_manager)
90
 
91
+ # Sync asset manager with existing globals
92
+ def _sync_asset_manager():
93
+ asset_manager.mean_embed = _MEAN_EMBED
94
+ asset_manager.centroids = _CENTROIDS
95
+ asset_manager.assets_repo_id = _ASSETS_REPO_ID
96
+
97
+ # def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
98
+ # """
99
+ # List available checkpoint steps in a HF model repo without downloading all weights.
100
+ # Looks for:
101
+ # checkpoint_<step>/
102
+ # checkpoint_<step>.tgz | .tar.gz
103
+ # archives/checkpoint_<step>.tgz | .tar.gz
104
+ # """
105
+ # api = HfApi()
106
+ # files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
107
+ # steps = set()
108
+ # for f in files:
109
+ # m = _STEP_RE.search(f)
110
+ # if m:
111
+ # try:
112
+ # steps.add(int(m.group(1)))
113
+ # except:
114
+ # pass
115
+ # return sorted(steps)
116
+
117
+ # def _step_exists(repo_id: str, revision: str, step: int) -> bool:
118
+ # return step in _list_ckpt_steps(repo_id, revision)
119
 
120
  def _any_jam_running() -> bool:
121
  with jam_lock:
 
129
  w.join(timeout=timeout)
130
  jam_registry.pop(sid, None)
131
 
132
+ # def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
133
+ # """
134
+ # Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
135
+ # Safe to call multiple times; will overwrite globals if successful.
136
+ # """
137
+ # global _ASSETS_REPO_ID, _MEAN_EMBED, _CENTROIDS
138
+ # repo_id = repo_id or _FINETUNE_REPO_DEFAULT
139
+ # try:
140
+ # from huggingface_hub import hf_hub_download
141
+ # mean_path = None
142
+ # cent_path = None
143
+ # try:
144
+ # mean_path = hf_hub_download(repo_id, filename="mean_style_embed.npy", repo_type="model")
145
+ # except Exception:
146
+ # pass
147
+ # try:
148
+ # cent_path = hf_hub_download(repo_id, filename="cluster_centroids.npy", repo_type="model")
149
+ # except Exception:
150
+ # pass
151
+
152
+ # if mean_path is None and cent_path is None:
153
+ # return False, f"No finetune asset files found in repo {repo_id}"
154
+
155
+ # if mean_path is not None:
156
+ # m = np.load(mean_path)
157
+ # if m.ndim != 1:
158
+ # return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
159
+ # else:
160
+ # m = None
161
+
162
+ # if cent_path is not None:
163
+ # c = np.load(cent_path)
164
+ # if c.ndim != 2:
165
+ # return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"
166
+ # else:
167
+ # c = None
168
+
169
+ # # Optional: shape check vs model embedding dim once model is alive
170
+ # try:
171
+ # d = int(get_mrt().style_model.config.embedding_dim)
172
+ # if m is not None and m.shape[0] != d:
173
+ # return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
174
+ # if c is not None and c.shape[1] != d:
175
+ # return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"
176
+ # except Exception:
177
+ # # Model not built yet; we’ll trust the files and rely on runtime checks later
178
+ # pass
179
+
180
+ # _MEAN_EMBED = m.astype(np.float32, copy=False) if m is not None else None
181
+ # _CENTROIDS = c.astype(np.float32, copy=False) if c is not None else None
182
+ # _ASSETS_REPO_ID = repo_id
183
+ # logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
184
+ # repo_id,
185
+ # "yes" if _MEAN_EMBED is not None else "no",
186
+ # f"{_CENTROIDS.shape[0]}x{_CENTROIDS.shape[1]}" if _CENTROIDS is not None else "no")
187
+ # return True, "ok"
188
+ # except Exception as e:
189
+ # logging.exception("Failed to load finetune assets: %s", e)
190
+ # return False, str(e)
191
+
192
+ # def _ensure_assets_loaded():
193
+ # # Best-effort lazy load if nothing is loaded yet
194
+ # if _MEAN_EMBED is None and _CENTROIDS is None:
195
+ # _load_finetune_assets_from_hf(_ASSETS_REPO_ID or _FINETUNE_REPO_DEFAULT)
196
  # ------------------------------------------------------------------------------
197
 
198
+ # def _resolve_checkpoint_dir() -> str | None:
199
+ # repo_id = os.getenv("MRT_CKPT_REPO")
200
+ # if not repo_id:
201
+ # return None
202
+ # step = os.getenv("MRT_CKPT_STEP") # e.g. "1863001"
203
+
204
+ # root = Path(snapshot_download(
205
+ # repo_id=repo_id,
206
+ # repo_type="model",
207
+ # revision=os.getenv("MRT_CKPT_REV", "main"),
208
+ # local_dir="/home/appuser/.cache/mrt_ckpt/repo",
209
+ # local_dir_use_symlinks=False,
210
+ # ))
211
+
212
+ # # Prefer an archive if present (more reliable for Zarr/T5X)
213
+ # arch_names = [
214
+ # f"checkpoint_{step}.tgz",
215
+ # f"checkpoint_{step}.tar.gz",
216
+ # f"archives/checkpoint_{step}.tgz",
217
+ # f"archives/checkpoint_{step}.tar.gz",
218
+ # ] if step else []
219
+
220
+ # cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
221
+ # cache_root.mkdir(parents=True, exist_ok=True)
222
+ # for name in arch_names:
223
+ # arch = root / name
224
+ # if arch.is_file():
225
+ # out_dir = cache_root / f"checkpoint_{step}"
226
+ # marker = out_dir.with_suffix(".ok")
227
+ # if not marker.exists():
228
+ # out_dir.mkdir(parents=True, exist_ok=True)
229
+ # with tarfile.open(arch, "r:*") as tf:
230
+ # tf.extractall(out_dir)
231
+ # marker.write_text("ok")
232
+ # # sanity: require .zarray to exist inside the extracted tree
233
+ # if not any(out_dir.rglob(".zarray")):
234
+ # raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
235
+ # return str(out_dir / f"checkpoint_{step}") if (out_dir / f"checkpoint_{step}").exists() else str(out_dir)
236
+
237
+ # # No archive; try raw folder from repo and sanity check.
238
+ # if step:
239
+ # raw = root / f"checkpoint_{step}"
240
+ # if raw.is_dir():
241
+ # if not any(raw.rglob(".zarray")):
242
+ # raise RuntimeError(
243
+ # f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
244
+ # "Upload as a .tgz or push via git from a Unix shell."
245
+ # )
246
+ # return str(raw)
247
+
248
+ # # Pick latest if no step
249
+ # step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
250
+ # if step_dirs:
251
+ # pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
252
+ # if not any(pick.rglob(".zarray")):
253
+ # raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
254
+ # return str(pick)
255
+
256
+ # return None
257
 
258
 
259
  async def send_json_safe(ws: WebSocket, obj) -> bool:
 
304
  # Call the patch immediately at import time (before MagentaRT init)
305
  _patch_t5x_for_gpu_coords()
306
307
  jam_registry: dict[str, JamWorker] = {}
308
  jam_lock = threading.Lock()
309
 
 
328
  except Exception:
329
  _HAS_LOUDNORM = False
330
331
  def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
332
  extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
333
  if not extra:
 
438
  if _MRT is None:
439
  with _MRT_LOCK:
440
  if _MRT is None:
441
+ from model_management import CheckpointManager
442
+ ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # ← Updated call
443
  _MRT = system.MagentaRT(
444
+ tag=os.getenv("MRT_SIZE", "large"),
445
  guidance_weight=5.0,
446
  device="gpu",
447
+ checkpoint_dir=ckpt_dir,
448
  lazy=False,
449
  )
450
  return _MRT
 
551
 
552
  @app.post("/model/assets/load")
553
  def model_assets_load(repo_id: str = Form(None)):
554
+ global _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
555
+ ok, msg = asset_manager.load_finetune_assets_from_hf(repo_id, get_mrt())
556
+ # Sync globals after loading
557
+ _MEAN_EMBED = asset_manager.mean_embed
558
+ _CENTROIDS = asset_manager.centroids
559
+ _ASSETS_REPO_ID = asset_manager.assets_repo_id
560
  return {"ok": ok, "message": msg, "repo_id": _ASSETS_REPO_ID,
561
  "mean": _MEAN_EMBED is not None,
562
  "centroids": None if _CENTROIDS is None else int(_CENTROIDS.shape[0])}
 
595
  step = os.getenv("MRT_CKPT_STEP")
596
  assets = os.getenv("MRT_ASSETS_REPO")
597
 
598
+ # Use CheckpointManager for local cache probe (no network)
599
+ local_ckpt = None
600
+ if step:
 
601
  try:
602
  from pathlib import Path
603
  import re
604
+ step_escaped = re.escape(str(step))
605
+ candidates = []
606
  for root in ("/home/appuser/.cache/mrt_ckpt/extracted",
607
  "/home/appuser/.cache/mrt_ckpt/repo"):
608
  p = Path(root)
 
612
  for d in p.rglob(f"checkpoint_{step}"):
613
  if d.is_dir():
614
  candidates.append(str(d))
615
+ local_ckpt = candidates[0] if candidates else None
616
  except Exception:
617
+ local_ckpt = None
 
 
618
 
619
  return {
620
  "size": size,
 
637
 
638
  @app.get("/model/checkpoints")
639
  def model_checkpoints(repo_id: str, revision: str = "main"):
640
+ steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
641
  return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
642
 
643
+ # class ModelSelect(BaseModel):
644
+ # size: Optional[Literal["base","large"]] = None
645
+ # repo_id: Optional[str] = None
646
+ # revision: Optional[str] = "main"
647
+ # step: Optional[Union[int, str]] = None # allow "latest"
648
+ # assets_repo_id: Optional[str] = None # default: follow repo_id
649
+ # sync_assets: bool = True # load mean/centroids from repo
650
+ # prewarm: bool = False # call get_mrt() to build right away
651
+ # stop_active: bool = True # auto-stop jams; else 409
652
+ # dry_run: bool = False # validate only, don't swap
653
 
654
  @app.post("/model/select")
655
  def model_select(req: ModelSelect):
656
+ global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
657
+
658
+ # Use ModelSelector to validate the request
659
+ success, validation_result = model_selector.validate_selection(req)
660
+ if not success:
661
+ if "error" in validation_result:
662
+ raise HTTPException(status_code=400, detail=validation_result["error"])
663
+ return {"ok": False, **validation_result}
664
+
665
+ # Add active_jam status to the validation result
666
+ validation_result["active_jam"] = _any_jam_running()
667
+
668
+ # If dry run, return the validation result
669
  if req.dry_run:
670
+ return {"ok": True, "dry_run": True, **validation_result}
671
 
672
+ # Handle jam policy
673
  if _any_jam_running():
674
  if req.stop_active:
675
  _stop_all_jams()
676
  else:
677
  raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
678
 
679
+ # Prepare environment changes
680
+ env_changes = model_selector.prepare_env_changes(req, validation_result)
681
+
682
+ # Save current environment for rollback
683
  old_env = {
684
+ "MRT_SIZE": os.getenv("MRT_SIZE"),
685
+ "MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
686
+ "MRT_CKPT_REV": os.getenv("MRT_CKPT_REV"),
687
+ "MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
688
+ "MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
689
  }
690
+
691
  try:
692
+ # Apply environment changes atomically
693
+ for key, value in env_changes.items():
694
+ if value is None:
695
+ os.environ.pop(key, None)
696
+ else:
697
+ os.environ[key] = str(value)
698
+
699
+ # Force model rebuild
 
700
  with _MRT_LOCK:
701
  _MRT = None
702
 
703
+ # Load finetune assets if requested
704
+ if req.sync_assets and validation_result.get("assets_repo"):
705
+ ok, msg = asset_manager.load_finetune_assets_from_hf(
706
+ validation_result["assets_repo"],
707
+ get_mrt() if req.prewarm else None
708
+ )
709
+ if ok:
710
+ # Sync globals after successful asset loading
711
+ _MEAN_EMBED = asset_manager.mean_embed
712
+ _CENTROIDS = asset_manager.centroids
713
+ _ASSETS_REPO_ID = asset_manager.assets_repo_id
714
+
715
+ # Optional prewarm to amortize JIT
716
  if req.prewarm:
717
  get_mrt()
718
 
719
+ return {"ok": True, **validation_result}
720
+
721
  except Exception as e:
722
+ # Rollback on error
723
  for k, v in old_env.items():
724
  if v is None:
725
  os.environ.pop(k, None)
 
727
  os.environ[k] = v
728
  with _MRT_LOCK:
729
  _MRT = None
730
+ # Try to restore working state
731
  try:
732
  get_mrt()
733
  except Exception:
 
914
  topk: int = Form(40),
915
  target_sample_rate: int | None = Form(None),
916
  ):
917
+ asset_manager.ensure_assets_loaded(get_mrt())
918
 
919
  # enforce single active jam per GPU
920
  with jam_lock:
 
1069
  mean: Optional[float] = Form(None),
1070
  centroid_weights: str = Form(""),
1071
  ):
1072
+ asset_manager.ensure_assets_loaded(get_mrt())
1073
 
1074
  with jam_lock:
1075
  worker = jam_registry.get(session_id)
 
1377
  state.context_tokens = tokens
1378
 
1379
  # Parse params (including steering)
1380
+ asset_manager.ensure_assets_loaded(get_mrt())
1381
  styles_str = params.get("styles", "warmup") or ""
1382
  style_weights_str = params.get("style_weights", "") or ""
1383
  mean_w = float(params.get("mean", 0.0) or 0.0)
 
1544
  text_list = [s for s in (styles_str.split(",") if styles_str else []) if s.strip()]
1545
  text_w = [float(x) for x in style_weights_str.split(",")] if style_weights_str else []
1546
 
1547
+ asset_manager.ensure_assets_loaded(get_mrt())
1548
  websocket._style_tgt = build_style_vector(
1549
  websocket._mrt,
1550
  text_styles=text_list,
 
1651
  <p>Documentation file not found. Please check documentation.html</p>
1652
  </body></html>
1653
  """
1654
+ return Response(content=html_content, media_type="text/html")
model_management.py ADDED
@@ -0,0 +1,374 @@
1
+ # model_management.py
2
+ """
3
+ Model management utilities for MagentaRT API.
4
+
5
+ This module handles checkpoint discovery, asset loading, and model selection logic.
6
+ It is designed to work with the global state managed in app.py without interfering
7
+ with the critical JAX/XLA initialization sequence.
8
+ """
9
+
10
+ import os
11
+ import re
12
+ import logging
13
+ from pathlib import Path
14
+ from typing import Optional, Union, Literal, Tuple, List
15
+ import tarfile
16
+
17
+ import numpy as np
18
+ from pydantic import BaseModel
19
+ from huggingface_hub import snapshot_download, HfApi, hf_hub_download
20
+
21
+
22
+ # ---- Constants and Patterns ----
23
+ _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
24
+ _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
25
+
26
+
27
+ # ---- Pydantic Models ----
28
+ class ModelSelect(BaseModel):
29
+ size: Optional[Literal["base","large"]] = None
30
+ repo_id: Optional[str] = None
31
+ revision: Optional[str] = "main"
32
+ step: Optional[Union[int, str]] = None # allow "latest"
33
+ assets_repo_id: Optional[str] = None # default: follow repo_id
34
+ sync_assets: bool = True # load mean/centroids from repo
35
+ prewarm: bool = False # call get_mrt() to build right away
36
+ stop_active: bool = True # auto-stop jams; else 409
37
+ dry_run: bool = False # validate only, don't swap
38
+
39
+
40
+ # ---- Checkpoint Discovery ----
41
+ class CheckpointManager:
42
+ """Handles checkpoint discovery and validation without modifying global state."""
43
+
44
+ @staticmethod
45
+ def list_ckpt_steps(repo_id: str, revision: str = "main") -> List[int]:
46
+ """
47
+ List available checkpoint steps in a HF model repo without downloading all weights.
48
+ Looks for:
49
+ checkpoint_<step>/
50
+ checkpoint_<step>.tgz | .tar.gz
51
+ archives/checkpoint_<step>.tgz | .tar.gz
52
+ """
53
+ api = HfApi()
54
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
55
+ steps = set()
56
+ for f in files:
57
+ m = _STEP_RE.search(f)
58
+ if m:
59
+ try:
60
+ steps.add(int(m.group(1)))
61
+ except:
62
+ pass
63
+ return sorted(steps)
64
+
65
+ @staticmethod
66
+ def step_exists(repo_id: str, revision: str, step: int) -> bool:
67
+ """Check if a specific checkpoint step exists in the repo."""
68
+ return step in CheckpointManager.list_ckpt_steps(repo_id, revision)
69
+
70
+ @staticmethod
71
+ def resolve_checkpoint_dir() -> Optional[str]:
72
+ """
73
+ Resolve the checkpoint directory from environment variables.
74
+ Downloads and extracts if necessary.
75
+ Returns the path to the checkpoint directory or None if not configured.
76
+ """
77
+ repo_id = os.getenv("MRT_CKPT_REPO")
78
+ if not repo_id:
79
+ return None
80
+ step = os.getenv("MRT_CKPT_STEP") # e.g. "1863001"
81
+
82
+ root = Path(snapshot_download(
83
+ repo_id=repo_id,
84
+ repo_type="model",
85
+ revision=os.getenv("MRT_CKPT_REV", "main"),
86
+ local_dir="/home/appuser/.cache/mrt_ckpt/repo",
87
+ local_dir_use_symlinks=False,
88
+ ))
89
+
90
+ # Prefer an archive if present (more reliable for Zarr/T5X)
91
+ arch_names = [
92
+ f"checkpoint_{step}.tgz",
93
+ f"checkpoint_{step}.tar.gz",
94
+ f"archives/checkpoint_{step}.tgz",
95
+ f"archives/checkpoint_{step}.tar.gz",
96
+ ] if step else []
97
+
98
+ cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
99
+ cache_root.mkdir(parents=True, exist_ok=True)
100
+ for name in arch_names:
101
+ arch = root / name
102
+ if arch.is_file():
103
+ out_dir = cache_root / f"checkpoint_{step}"
104
+ marker = out_dir.with_suffix(".ok")
105
+ if not marker.exists():
106
+ out_dir.mkdir(parents=True, exist_ok=True)
107
+ with tarfile.open(arch, "r:*") as tf:
108
+ tf.extractall(out_dir)
109
+ marker.write_text("ok")
110
+ # sanity: require .zarray to exist inside the extracted tree
111
+ if not any(out_dir.rglob(".zarray")):
112
+ raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
113
+ return str(out_dir / f"checkpoint_{step}") if (out_dir / f"checkpoint_{step}").exists() else str(out_dir)
114
+
115
+ # No archive; try raw folder from repo and sanity check.
116
+ if step:
117
+ raw = root / f"checkpoint_{step}"
118
+ if raw.is_dir():
119
+ if not any(raw.rglob(".zarray")):
120
+ raise RuntimeError(
121
+ f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
122
+ "Upload as a .tgz or push via git from a Unix shell."
123
+ )
124
+ return str(raw)
125
+
126
+ # Pick latest if no step
127
+ step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\d+$", d.name)]
128
+ if step_dirs:
129
+ pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
130
+ if not any(pick.rglob(".zarray")):
131
+ raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
132
+ return str(pick)
133
+
134
+ return None
+
+
+ # ---- Asset Management ----
+ class AssetManager:
+     """
+     Handles finetune asset loading and management.
+
+     The loaded assets are held on this instance; the calling module is responsible
+     for mirroring them into its globals. This class only encapsulates the loading
+     and validation logic.
+     """
+
+     def __init__(self):
+         # Populated by load_finetune_assets_from_hf() or set directly by the calling module
+         self.mean_embed = None
+         self.centroids = None
+         self.assets_repo_id = None
+
+     def load_finetune_assets_from_hf(self, repo_id: Optional[str], mrt=None) -> Tuple[bool, str]:
+         """
+         Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
+         Safe to call multiple times; will overwrite instance vars if successful.
+
+         Args:
+             repo_id: HuggingFace repo ID, defaults to _FINETUNE_REPO_DEFAULT
+             mrt: MagentaRT instance for dimension validation (optional)
+
+         Returns:
+             Tuple of (success: bool, message: str)
+         """
+         repo_id = repo_id or _FINETUNE_REPO_DEFAULT
+         try:
+             mean_path = None
+             cent_path = None
+             try:
+                 mean_path = hf_hub_download(repo_id, filename="mean_style_embed.npy", repo_type="model")
+             except Exception:
+                 pass
+             try:
+                 cent_path = hf_hub_download(repo_id, filename="cluster_centroids.npy", repo_type="model")
+             except Exception:
+                 pass
+
+             if mean_path is None and cent_path is None:
+                 return False, f"No finetune asset files found in repo {repo_id}"
+
+             if mean_path is not None:
+                 m = np.load(mean_path)
+                 if m.ndim != 1:
+                     return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
+             else:
+                 m = None
+
+             if cent_path is not None:
+                 c = np.load(cent_path)
+                 if c.ndim != 2:
+                     return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"
+             else:
+                 c = None
+
+             # Optional: shape check vs model embedding dim once model is alive
+             if mrt is not None:
+                 try:
+                     d = int(mrt.style_model.config.embedding_dim)
+                     if m is not None and m.shape[0] != d:
+                         return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
+                     if c is not None and c.shape[1] != d:
+                         return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"
+                 except Exception:
+                     # Model not built yet; we'll trust the files and rely on runtime checks later
+                     pass
+
+             # Update instance variables
+             self.mean_embed = m.astype(np.float32, copy=False) if m is not None else None
+             self.centroids = c.astype(np.float32, copy=False) if c is not None else None
+             self.assets_repo_id = repo_id
+
+             logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
+                          repo_id,
+                          "yes" if self.mean_embed is not None else "no",
+                          f"{self.centroids.shape[0]}x{self.centroids.shape[1]}" if self.centroids is not None else "no")
+             return True, "ok"
+         except Exception as e:
+             logging.exception("Failed to load finetune assets: %s", e)
+             return False, str(e)
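+
+     # Illustrative usage sketch (the repo id is a placeholder):
+     #
+     #     assets = AssetManager()
+     #     ok, msg = assets.load_finetune_assets_from_hf("someuser/some-finetune", mrt=None)
+     #     if ok:
+     #         print(assets.centroids.shape if assets.centroids is not None else "no centroids")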
+
+     def ensure_assets_loaded(self, mrt=None):
+         """Best-effort lazy load if nothing is loaded yet."""
+         if self.mean_embed is None and self.centroids is None:
+             self.load_finetune_assets_from_hf(self.assets_repo_id or _FINETUNE_REPO_DEFAULT, mrt)
+
+     def get_status(self, mrt=None) -> dict:
+         """Get current asset status."""
+         d = None
+         if mrt is not None:
+             try:
+                 d = int(mrt.style_model.config.embedding_dim)
+             except Exception:
+                 pass
+
+         return {
+             "repo_id": self.assets_repo_id,
+             "mean_loaded": self.mean_embed is not None,
+             "centroids_loaded": self.centroids is not None,
+             "centroid_count": None if self.centroids is None else int(self.centroids.shape[0]),
+             "embedding_dim": d,
+         }
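+
+     # Example of the returned shape (values are illustrative):
+     #
+     #     {
+     #         "repo_id": "someuser/some-finetune",
+     #         "mean_loaded": True,
+     #         "centroids_loaded": True,
+     #         "centroid_count": 8,
+     #         "embedding_dim": None,   # filled in once the model is built and mrt is passed
+     #     }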
+
+
+ # ---- Model Selection Logic ----
+ class ModelSelector:
+     """
+     Handles model selection and validation logic.
+
+     This class encapsulates the complex logic from the /model/select endpoint
+     while keeping environment variable management in the calling code.
+     """
+
+     def __init__(self, checkpoint_manager: CheckpointManager, asset_manager: AssetManager):
+         self.checkpoint_manager = checkpoint_manager
+         self.asset_manager = asset_manager
+
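+     # Typical wiring in the calling module (names are illustrative):
+     #
+     #     selector = ModelSelector(CheckpointManager(), AssetManager())
+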
+     def validate_selection(self, req: ModelSelect) -> Tuple[bool, dict]:
+         """
+         Validate a model selection request without making any changes.
+
+         Returns:
+             Tuple of (success: bool, result_dict: dict)
+         """
+         # Current env defaults
+         cur = {
+             "size": os.getenv("MRT_SIZE", "large"),
+             "repo": os.getenv("MRT_CKPT_REPO"),
+             "rev": os.getenv("MRT_CKPT_REV", "main"),
+             "step": os.getenv("MRT_CKPT_STEP"),
+             "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
+         }
+
+         # Flags for special step values
+         no_ckpt = isinstance(req.step, str) and req.step.lower() == "none"
+         latest = isinstance(req.step, str) and req.step.lower() == "latest"
+
+         # Target selection
+         tgt = {
+             "size": req.size or cur["size"],
+             "repo": None if no_ckpt else (req.repo_id or cur["repo"]),
+             "rev": req.revision if req.revision is not None else cur["rev"],
+             "step": None if (no_ckpt or latest) else (str(req.step) if req.step is not None else cur["step"]),
+             "assets": req.assets_repo_id or req.repo_id or cur["assets"],
+         }
+
+         # Case 1: No checkpoint (stock model)
+         if no_ckpt:
+             return True, {
+                 "target_size": tgt["size"],
+                 "target_repo": None,
+                 "target_revision": None,
+                 "target_step": None,
+                 "assets_repo": None,
+                 "assets_probe": {"ok": True, "message": "skipped"},
+             }
+
+         # Case 2: Checkpoint selection
+         if not tgt["repo"]:
+             return False, {"error": "repo_id is required for model selection."}
+
+         # Enumerate available steps
+         try:
+             steps = self.checkpoint_manager.list_ckpt_steps(tgt["repo"], tgt["rev"])
+         except Exception as e:
+             return False, {"error": f"Failed to list checkpoints: {e}"}
+
+         if not steps:
+             return False, {
+                 "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}",
+                 "discovered_steps": steps
+             }
+
+         # Choose step (explicit or latest)
+         chosen_step = int(tgt["step"]) if tgt["step"] is not None else steps[-1]
+         if chosen_step not in steps:
+             return False, {
+                 "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}",
+                 "discovered_steps": steps
+             }
+
+         # Optional finetune assets probe
+         assets_ok, assets_msg = True, "skipped"
+         if req.sync_assets:
+             try:
+                 api = HfApi()
+                 files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
+                 if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
+                     assets_ok, assets_msg = False, f"No finetune asset files in {tgt['assets']}"
+                 else:
+                     assets_msg = "found"
+             except Exception as e:
+                 assets_ok, assets_msg = False, f"probe failed: {e}"
+
+         return True, {
+             "target_size": tgt["size"],
+             "target_repo": tgt["repo"],
+             "target_revision": tgt["rev"],
+             "target_step": chosen_step,
+             "assets_repo": tgt["assets"] if req.sync_assets else None,
+             "assets_probe": {"ok": assets_ok, "message": assets_msg},
+         }
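+
+     # Illustrative call (field values are placeholders; ModelSelect comes from the API layer):
+     #
+     #     ok, result = selector.validate_selection(
+     #         ModelSelect(size="large", repo_id="someuser/some-finetune", step="latest", sync_assets=True)
+     #     )
+     #     # ok is False when the repo has no checkpoints or the requested step is missing;
+     #     # result then carries "error" and "discovered_steps".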
+
+     def prepare_env_changes(self, req: ModelSelect, validation_result: dict) -> dict:
+         """
+         Prepare the environment variable changes needed for a model selection.
+
+         Args:
+             req: The model selection request
+             validation_result: Result from validate_selection()
+
+         Returns:
+             Dictionary of environment variable changes to apply
+         """
+         no_ckpt = isinstance(req.step, str) and req.step.lower() == "none"
+
+         if no_ckpt:
+             # Clear checkpoint env vars for stock model
+             return {
+                 "MRT_SIZE": validation_result["target_size"],
+                 "MRT_CKPT_REPO": None,  # None means delete the env var
+                 "MRT_CKPT_REV": None,
+                 "MRT_CKPT_STEP": None,
+                 "MRT_ASSETS_REPO": None,
+             }
+         else:
+             # Set checkpoint env vars
+             env_changes = {
+                 "MRT_SIZE": validation_result["target_size"],
+                 "MRT_CKPT_REPO": validation_result["target_repo"],
+                 "MRT_CKPT_REV": validation_result["target_revision"],
+                 "MRT_CKPT_STEP": str(validation_result["target_step"]),
+             }
+             if req.sync_assets:
+                 env_changes["MRT_ASSETS_REPO"] = validation_result["assets_repo"]
+             return env_changes
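+
+ # Sketch of how a caller might apply the returned mapping (the os.environ handling here is
+ # illustrative; the actual wiring lives in the calling module, app.py):
+ #
+ #     for key, value in selector.prepare_env_changes(req, result).items():
+ #         if value is None:
+ #             os.environ.pop(key, None)   # None means delete the env var
+ #         else:
+ #             os.environ[key] = value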