Spaces:
Running
Running
# model_management.py | |
""" | |
Model management utilities for MagentaRT API. | |
This module handles checkpoint discovery, asset loading, and model selection logic. | |
It is designed to work with the global state managed in app.py without interfering | |
with the critical JAX/XLA initialization sequence. | |
""" | |
import os | |
import re | |
import logging | |
from pathlib import Path | |
from typing import Optional, Union, Literal, Tuple, List | |
import tarfile | |
import numpy as np | |
from pydantic import BaseModel | |
from huggingface_hub import snapshot_download, HfApi, hf_hub_download | |
# ---- Constants and Patterns ----
# Default HF repo for finetune assets (mean_style_embed.npy / cluster_centroids.npy).
# Read once at import time; MRT_ASSETS_REPO overrides the built-in default.
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
# Captures <step> from repo *file* paths belonging to a checkpoint. HfApi.list_repo_files
# returns files only, so a folder checkpoint appears as "checkpoint_<step>/<file>":
#   checkpoint_<step>/<anything>              -> raw checkpoint folder
#   [dir/]checkpoint_<step>.tgz | .tar.gz     -> archive
#   [dir/]checkpoint_<step>                   -> bare file
# "checkpoint_" must begin the path or follow a "/" (so "my_checkpoint_3" is ignored).
_STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|(?:\.tar\.gz|\.tgz)?$)")
# ---- Pydantic Models ----
class ModelSelect(BaseModel):
    """Request body for switching the active MagentaRT model / checkpoint.

    Fields left as None are resolved later against the current MRT_* environment
    configuration by the selection logic.
    """

    # Model size; None keeps the currently configured size.
    size: Union[Literal["base", "large"], None] = None
    # HF model repo that holds the checkpoint_<step> folders/archives.
    repo_id: Union[str, None] = None
    # Git revision (branch / tag / commit) of repo_id.
    revision: Union[str, None] = "main"
    # Checkpoint step: an int, or a keyword string such as "latest".
    step: Union[int, str, None] = None
    # Repo to pull finetune assets from; None means "follow repo_id".
    assets_repo_id: Union[str, None] = None
    # Load mean/centroid assets from the assets repo.
    sync_assets: bool = True
    # Build the model immediately (call get_mrt()) after switching.
    prewarm: bool = False
    # Auto-stop active jams; when False the request is answered with 409.
    stop_active: bool = True
    # Validate the request only; do not swap models.
    dry_run: bool = False
# ---- Checkpoint Discovery ----
class CheckpointManager:
    """Handles checkpoint discovery and validation without modifying global state.

    Every method is a ``@staticmethod``, so it can be called on the class
    (``CheckpointManager.list_ckpt_steps(...)``) or on an instance (as
    ``ModelSelector`` does through ``self.checkpoint_manager``). Without the
    decorator, an instance call binds the instance to the first parameter and
    fails with ``TypeError``.
    """

    @staticmethod
    def list_ckpt_steps(repo_id: str, revision: str = "main") -> List[int]:
        """
        List available checkpoint steps in a HF model repo without downloading weights.

        Only the repo's file listing is fetched. Each path is matched against the
        module-level ``_STEP_RE`` pattern, which is meant to recognise:
            checkpoint_<step>/...
            checkpoint_<step>.tgz | .tar.gz
            archives/checkpoint_<step>.tgz | .tar.gz

        Args:
            repo_id: HuggingFace model repo id.
            revision: Branch, tag or commit to list.

        Returns:
            Sorted list of distinct step numbers (empty if nothing matches).
        """
        files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
        # The captured group is \d+, so int() cannot fail (no try/except needed).
        return sorted({int(m.group(1)) for m in map(_STEP_RE.search, files) if m})

    @staticmethod
    def step_exists(repo_id: str, revision: str, step: int) -> bool:
        """Return True if checkpoint ``step`` is listed in ``repo_id@revision``."""
        return step in CheckpointManager.list_ckpt_steps(repo_id, revision)

    @staticmethod
    def _extract_archive(arch: Path, out_dir: Path) -> None:
        """Extract tar archive ``arch`` (any compression) into ``out_dir``."""
        out_dir.mkdir(parents=True, exist_ok=True)
        with tarfile.open(arch, "r:*") as tf:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from a possibly user-selected repo: the "data"
                # filter (PEP 706) rejects absolute paths, ".." members and links
                # that would escape out_dir.
                tf.extractall(out_dir, filter="data")
            else:
                # Interpreter predates extraction filters; members are not sanitised.
                tf.extractall(out_dir)

    @staticmethod
    def resolve_checkpoint_dir() -> Optional[str]:
        """
        Resolve a local checkpoint directory from MRT_CKPT_* environment variables.

        Reads MRT_CKPT_REPO (required), MRT_CKPT_REV (default "main") and
        MRT_CKPT_STEP (optional). Snapshots the repo into
        /home/appuser/.cache/mrt_ckpt/repo, then tries, in order:
          1. a checkpoint_<step>.tgz / .tar.gz archive (top level or archives/),
             extracted once per step under /home/appuser/.cache/mrt_ckpt/extracted
             (a "<dir>.ok" marker records a finished extraction);
          2. the raw checkpoint_<step>/ folder from the snapshot;
          3. the highest-numbered checkpoint_<N>/ folder, logging a warning if a
             specific step was requested but not found.

        Returns:
            Path to the checkpoint directory, or None if MRT_CKPT_REPO is unset or
            no checkpoint folder/archive could be located.

        Raises:
            RuntimeError: if the selected checkpoint contains no ``.zarray`` files.
        """
        repo_id = os.getenv("MRT_CKPT_REPO")
        if not repo_id:
            return None
        step = os.getenv("MRT_CKPT_STEP")  # e.g. "1863001"

        root = Path(snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            revision=os.getenv("MRT_CKPT_REV", "main"),
            local_dir="/home/appuser/.cache/mrt_ckpt/repo",
            local_dir_use_symlinks=False,
        ))
        cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
        cache_root.mkdir(parents=True, exist_ok=True)

        if step:
            # Prefer an archive if present (more reliable for Zarr/T5X)
            for name in (
                f"checkpoint_{step}.tgz",
                f"checkpoint_{step}.tar.gz",
                f"archives/checkpoint_{step}.tgz",
                f"archives/checkpoint_{step}.tar.gz",
            ):
                arch = root / name
                if not arch.is_file():
                    continue
                out_dir = cache_root / f"checkpoint_{step}"
                marker = out_dir.with_suffix(".ok")
                if not marker.exists():
                    CheckpointManager._extract_archive(arch, out_dir)
                    # Written only after extraction completes, so an interrupted run re-extracts.
                    marker.write_text("ok")
                # sanity: require .zarray to exist inside the extracted tree
                if not any(out_dir.rglob(".zarray")):
                    raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
                # The archive may wrap its contents in a top-level checkpoint_<step>/ folder.
                nested = out_dir / f"checkpoint_{step}"
                return str(nested) if nested.exists() else str(out_dir)

            # No archive; try the raw folder from the repo and sanity check it.
            raw = root / f"checkpoint_{step}"
            if raw.is_dir():
                if not any(raw.rglob(".zarray")):
                    raise RuntimeError(
                        f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
                        "Upload as a .tgz or push via git from a Unix shell."
                    )
                return str(raw)

        # Fall back to the newest raw checkpoint folder in the snapshot.
        step_dirs = [
            d for d in root.iterdir()
            if d.is_dir() and re.fullmatch(r"checkpoint_\d+", d.name)
        ]
        if not step_dirs:
            return None
        pick = max(step_dirs, key=lambda d: int(d.name.split("_")[-1]))
        if step:
            # An explicit step was requested but not found; make the substitution visible.
            logging.warning(
                "checkpoint_%s not found in %s; falling back to %s", step, repo_id, pick.name
            )
        if not any(pick.rglob(".zarray")):
            raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
        return str(pick)
# ---- Asset Management ----
class AssetManager:
    """
    Handles finetune asset loading and management.

    Loaded assets live on the instance (``mean_embed``, ``centroids``,
    ``assets_repo_id``). Callers read them from there; this class does not touch
    module globals.
    """

    def __init__(self):
        # All None until load_finetune_assets_from_hf() succeeds.
        self.mean_embed = None      # 1-D float32 array or None
        self.centroids = None       # 2-D float32 array (one row per centroid) or None
        self.assets_repo_id = None  # repo the current assets were loaded from

    @staticmethod
    def _model_embedding_dim(mrt) -> Optional[int]:
        """Return ``int(mrt.style_model.config.embedding_dim)``, or None if unavailable."""
        if mrt is None:
            return None
        try:
            return int(mrt.style_model.config.embedding_dim)
        except Exception:
            # Model not built yet or exposes a different structure: treat the dim as unknown.
            return None

    @staticmethod
    def _try_download(repo_id: str, filename: str, errors: List[str]) -> Optional[str]:
        """Best-effort ``hf_hub_download``; on failure record the reason in ``errors``."""
        try:
            return hf_hub_download(repo_id, filename=filename, repo_type="model")
        except Exception as e:  # a missing file is expected; keep the reason for reporting
            errors.append(f"{filename}: {e}")
            return None

    def load_finetune_assets_from_hf(self, repo_id: Optional[str], mrt=None) -> Tuple[bool, str]:
        """
        Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.

        Either file may be absent; at least one must be present. This is safe to call
        multiple times. On success, all three instance attributes are overwritten,
        and an asset missing from the repo becomes None.

        Args:
            repo_id: HuggingFace repo ID; falsy values fall back to _FINETUNE_REPO_DEFAULT.
            mrt: MagentaRT instance used to validate the embedding dimension (optional).

        Returns:
            Tuple of (success: bool, message: str). The message is "ok" on success,
            otherwise a description of the failure.
        """
        repo_id = repo_id or _FINETUNE_REPO_DEFAULT
        try:
            errors: List[str] = []
            mean_path = self._try_download(repo_id, "mean_style_embed.npy", errors)
            cent_path = self._try_download(repo_id, "cluster_centroids.npy", errors)
            if mean_path is None and cent_path is None:
                # Include the download errors so auth/network failures aren't
                # misreported as simply "no files".
                detail = "; ".join(errors)
                msg = f"No finetune asset files found in repo {repo_id}"
                return False, f"{msg} ({detail})" if detail else msg

            # allow_pickle=False: the repo may be user-chosen, so never unpickle its contents.
            m = np.load(mean_path, allow_pickle=False) if mean_path is not None else None
            if m is not None and m.ndim != 1:
                return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
            c = np.load(cent_path, allow_pickle=False) if cent_path is not None else None
            if c is not None and c.ndim != 2:
                return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"

            # Optional: shape check vs model embedding dim once the model is alive.
            # If the dim is unknown, trust the files and rely on later runtime checks.
            d = self._model_embedding_dim(mrt)
            if d is not None:
                if m is not None and m.shape[0] != d:
                    return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
                if c is not None and c.shape[1] != d:
                    return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"

            # Update instance variables only after every check has passed.
            self.mean_embed = m.astype(np.float32, copy=False) if m is not None else None
            self.centroids = c.astype(np.float32, copy=False) if c is not None else None
            self.assets_repo_id = repo_id
            logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
                         repo_id,
                         "yes" if self.mean_embed is not None else "no",
                         f"{self.centroids.shape[0]}x{self.centroids.shape[1]}" if self.centroids is not None else "no")
            return True, "ok"
        except Exception as e:
            logging.exception("Failed to load finetune assets: %s", e)
            return False, str(e)

    def ensure_assets_loaded(self, mrt=None):
        """Best-effort lazy load if nothing is loaded yet. The load result is ignored."""
        if self.mean_embed is None and self.centroids is None:
            self.load_finetune_assets_from_hf(self.assets_repo_id or _FINETUNE_REPO_DEFAULT, mrt)

    def get_status(self, mrt=None) -> dict:
        """Return a JSON-friendly summary of the loaded assets and the model embedding dim."""
        return {
            "repo_id": self.assets_repo_id,
            "mean_loaded": self.mean_embed is not None,
            "centroids_loaded": self.centroids is not None,
            "centroid_count": None if self.centroids is None else int(self.centroids.shape[0]),
            "embedding_dim": self._model_embedding_dim(mrt),
        }
# ---- Model Selection Logic ----
class ModelSelector:
    """
    Handles model selection and validation logic.

    This class encapsulates the logic behind the /model/select endpoint. It only
    computes results and env-var changes; applying them is left to the caller.
    """

    def __init__(self, checkpoint_manager: CheckpointManager, asset_manager: AssetManager):
        self.checkpoint_manager = checkpoint_manager
        self.asset_manager = asset_manager

    @staticmethod
    def _special_step(step) -> Optional[str]:
        """Return "none" or "latest" if ``step`` is that keyword (case-insensitive), else None."""
        if isinstance(step, str) and step.lower() in ("none", "latest"):
            return step.lower()
        return None

    def validate_selection(self, req: ModelSelect) -> Tuple[bool, dict]:
        """
        Validate a model selection request without making any changes.

        Unset request fields fall back to the current MRT_* environment values.
        Step resolution order:
          * ``step="none"`` selects the stock model (no checkpoint);
          * ``step="latest"`` or no step picks the newest discovered step, except that
            with no step on the *same* repo as MRT_CKPT_REPO, MRT_CKPT_STEP is kept;
          * otherwise the step must be an integer present in the repo.

        Returns:
            Tuple of (success: bool, result_dict: dict). On failure the dict has an
            "error" key (plus "discovered_steps" when the listing succeeded).
        """
        # Current env defaults
        cur = {
            "size": os.getenv("MRT_SIZE", "large"),
            "repo": os.getenv("MRT_CKPT_REPO"),
            "rev": os.getenv("MRT_CKPT_REV", "main"),
            "step": os.getenv("MRT_CKPT_STEP"),
            "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
        }
        special = self._special_step(req.step)
        size = req.size or cur["size"]

        # Case 1: No checkpoint (stock model)
        if special == "none":
            return True, {
                "target_size": size,
                "target_repo": None,
                "target_revision": None,
                "target_step": None,
                "assets_repo": None,
                "assets_probe": {"ok": True, "message": "skipped"},
            }

        # Case 2: Checkpoint selection
        repo = req.repo_id or cur["repo"]
        if not repo:
            return False, {"error": "repo_id is required for model selection."}
        rev = req.revision if req.revision is not None else cur["rev"]
        assets = req.assets_repo_id or req.repo_id or cur["assets"]

        if special == "latest":
            step = None
        elif req.step is not None:
            step = str(req.step)
        elif repo == cur["repo"]:
            # Keep the configured step only when staying on the same repo; another
            # repo's step number is meaningless here. Empty env value counts as unset.
            step = cur["step"] or None
        else:
            step = None

        # Enumerate available steps
        try:
            steps = self.checkpoint_manager.list_ckpt_steps(repo, rev)
        except Exception as e:
            return False, {"error": f"Failed to list checkpoints: {e}"}
        if not steps:
            return False, {
                "error": f"No checkpoint files found in {repo}@{rev}",
                "discovered_steps": steps,
            }

        # Choose step (explicit or latest)
        if step is None:
            chosen_step = steps[-1]
        else:
            try:
                chosen_step = int(step)
            except ValueError:
                # Previously an uncaught ValueError (e.g. step="abc").
                return False, {
                    "error": f"Invalid checkpoint step {step!r}; expected an integer, 'latest' or 'none'.",
                    "discovered_steps": steps,
                }
        if chosen_step not in steps:
            return False, {
                "error": f"checkpoint_{chosen_step} not present in {repo}@{rev}",
                "discovered_steps": steps,
            }

        # Optional finetune assets probe (file listing only, no download)
        assets_ok, assets_msg = True, "skipped"
        if req.sync_assets:
            try:
                files = set(HfApi().list_repo_files(repo_id=assets, repo_type="model"))
                if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
                    assets_ok, assets_msg = False, f"No finetune asset files in {assets}"
                else:
                    assets_msg = "found"
            except Exception as e:
                assets_ok, assets_msg = False, f"probe failed: {e}"

        return True, {
            "target_size": size,
            "target_repo": repo,
            "target_revision": rev,
            "target_step": chosen_step,
            "assets_repo": assets if req.sync_assets else None,
            "assets_probe": {"ok": assets_ok, "message": assets_msg},
        }

    def prepare_env_changes(self, req: ModelSelect, validation_result: dict) -> dict:
        """
        Prepare the environment variable changes needed for a model selection.

        Args:
            req: The model selection request
            validation_result: Successful result dict from validate_selection()

        Returns:
            Mapping of env var name -> new value; a value of None means
            "delete the env var".
        """
        if self._special_step(req.step) == "none":
            # Clear checkpoint env vars for stock model
            return {
                "MRT_SIZE": validation_result["target_size"],
                "MRT_CKPT_REPO": None,
                "MRT_CKPT_REV": None,
                "MRT_CKPT_STEP": None,
                "MRT_ASSETS_REPO": None,
            }
        env_changes = {
            "MRT_SIZE": validation_result["target_size"],
            "MRT_CKPT_REPO": validation_result["target_repo"],
            "MRT_CKPT_REV": validation_result["target_revision"],
            "MRT_CKPT_STEP": str(validation_result["target_step"]),
        }
        if req.sync_assets:
            env_changes["MRT_ASSETS_REPO"] = validation_result["assets_repo"]
        return env_changes