|
|
|
""" |
|
Model management utilities for MagentaRT API.

This module handles checkpoint discovery, asset loading, and model selection logic.
It is designed to work with the global state managed in app.py without interfering
with the critical JAX/XLA initialization sequence.
|
""" |
|
|
|
import os |
|
import re |
|
import logging |
|
from pathlib import Path |
|
from typing import Optional, Union, Literal, Tuple, List |
|
import tarfile |
|
|
|
import numpy as np |
|
from pydantic import BaseModel |
|
from huggingface_hub import snapshot_download, HfApi, hf_hub_download |
|
|
|
|
|
|
|
# Default Hugging Face repo holding finetune assets. The MRT_ASSETS_REPO
# environment variable overrides it; the value is read once, at import time.
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")

# Captures the step number of a checkpoint referenced by a repo file path.
# Accepted layouts (optionally below a parent directory such as "archives/"):
#   checkpoint_<step>                 bare name at the end of the path
#   checkpoint_<step>/<anything>      a file stored inside a checkpoint directory
#   checkpoint_<step>.tgz | .tar.gz   an archive at the end of the path
# Repo listings contain file paths only, so the directory layout must be
# recognised from the files *inside* the directory rather than from a path
# that ends in "checkpoint_<step>/".
_STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz$|\.tgz$|$)")
|
|
|
|
|
|
|
class ModelSelect(BaseModel):
    """Request payload describing a model / checkpoint / assets selection.

    Every field has a default, so an empty payload is valid. How ``size``,
    ``repo_id``, ``revision``, ``step``, ``assets_repo_id`` and ``sync_assets``
    are interpreted is defined by ``ModelSelector`` in this module.
    """

    # Model size; when None the currently configured size is kept.
    size: Optional[Literal["base", "large"]] = None

    # Hugging Face repo that holds the checkpoints.
    repo_id: Optional[str] = None

    # Repo revision to inspect.
    revision: Optional[str] = "main"

    # Checkpoint step: an integer, or the keywords "latest" / "none".
    step: Optional[Union[int, str]] = None

    # Repo holding finetune assets; the selector falls back to ``repo_id``.
    assets_repo_id: Optional[str] = None

    # Whether the assets repo should be probed and switched as well.
    sync_assets: bool = True

    # The flags below are not consumed in this module.
    # NOTE(review): presumably handled by the endpoint in app.py - confirm.
    prewarm: bool = False
    stop_active: bool = True
    dry_run: bool = False
|
|
|
|
|
|
|
class CheckpointManager:
    """Handles checkpoint discovery and validation without modifying global state.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def list_ckpt_steps(repo_id: str, revision: str = "main") -> List[int]:
        """List checkpoint steps available in a HF model repo.

        Only the repo's file listing is fetched; no weights are downloaded.
        Each path is matched against the module-level ``_STEP_RE`` pattern.
        The intended layouts are::

            checkpoint_<step>/
            checkpoint_<step>.tgz | .tar.gz
            archives/checkpoint_<step>.tgz | .tar.gz

        Args:
            repo_id: Hugging Face model repo to inspect.
            revision: Repo revision to list.

        Returns:
            The distinct step numbers found, in ascending order.
        """
        api = HfApi()
        files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
        matches = (_STEP_RE.search(path) for path in files)
        # The captured group is ``\d+``, so int() cannot fail here.
        return sorted({int(m.group(1)) for m in matches if m})

    @staticmethod
    def step_exists(repo_id: str, revision: str, step: int) -> bool:
        """Check if a specific checkpoint step exists in the repo."""
        return step in CheckpointManager.list_ckpt_steps(repo_id, revision)

    @staticmethod
    def _extract_archive(archive: Path, dest: Path) -> None:
        """Extract the tar ``archive`` into ``dest``, refusing unsafe members.

        The archive comes from a downloaded repo, so every member (and the
        target of every link member) must resolve to a location inside
        ``dest``. Members are validated lazily while ``extractall`` iterates
        them, so the archive is read only once.

        Raises:
            RuntimeError: If a member would be written outside ``dest``.
        """
        dest_root = dest.resolve()

        def _inside(candidate: Path) -> bool:
            resolved = candidate.resolve()
            return resolved == dest_root or dest_root in resolved.parents

        def _checked(tf):
            for member in tf:
                target = dest_root / member.name
                safe = _inside(target)
                if safe and member.issym():
                    # Symlink targets are relative to the link's own directory.
                    safe = _inside(target.parent / member.linkname)
                elif safe and member.islnk():
                    # Hard-link targets are relative to the archive root.
                    safe = _inside(dest_root / member.linkname)
                if not safe:
                    raise RuntimeError(
                        f"Refusing to extract unsafe member {member.name!r} from {archive}"
                    )
                yield member

        with tarfile.open(archive, "r:*") as tf:
            if hasattr(tarfile, "data_filter"):
                # Interpreters with extraction filters: also apply the
                # standard "data" filter as a second line of defence.
                tf.extractall(dest, members=_checked(tf), filter="data")
            else:
                tf.extractall(dest, members=_checked(tf))

    @staticmethod
    def resolve_checkpoint_dir() -> Optional[str]:
        """Resolve the checkpoint directory from environment variables.

        Reads ``MRT_CKPT_REPO`` (required), ``MRT_CKPT_REV`` (default
        ``"main"``) and ``MRT_CKPT_STEP`` (optional). The repo snapshot is
        downloaded, then the checkpoint is looked up in this order:

        1. an archive for the requested step (extracted once into a cache
           directory, guarded by a ``.ok`` marker file),
        2. a plain ``checkpoint_<step>`` directory for the requested step,
        3. the highest-numbered ``checkpoint_<n>`` directory in the snapshot.

        Returns:
            Path to the checkpoint directory, or None when ``MRT_CKPT_REPO``
            is unset or no checkpoint directory could be found.

        Raises:
            RuntimeError: If the selected checkpoint contains no ``.zarray``
                file, or an archive member would escape the extraction dir.
        """
        repo_id = os.getenv("MRT_CKPT_REPO")
        if not repo_id:
            return None
        step = os.getenv("MRT_CKPT_STEP")

        root = Path(snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            revision=os.getenv("MRT_CKPT_REV", "main"),
            local_dir="/home/appuser/.cache/mrt_ckpt/repo",
            local_dir_use_symlinks=False,
        ))

        # 1) Archive for the requested step.
        archive_names = [
            f"checkpoint_{step}.tgz",
            f"checkpoint_{step}.tar.gz",
            f"archives/checkpoint_{step}.tgz",
            f"archives/checkpoint_{step}.tar.gz",
        ] if step else []

        cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
        cache_root.mkdir(parents=True, exist_ok=True)
        for name in archive_names:
            arch = root / name
            if not arch.is_file():
                continue
            out_dir = cache_root / f"checkpoint_{step}"
            # The marker records a completed extraction so later calls skip it.
            marker = out_dir.with_suffix(".ok")
            if not marker.exists():
                out_dir.mkdir(parents=True, exist_ok=True)
                CheckpointManager._extract_archive(arch, out_dir)
                marker.write_text("ok")

            if not any(out_dir.rglob(".zarray")):
                raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
            # Archives may or may not wrap their content in checkpoint_<step>/.
            nested = out_dir / f"checkpoint_{step}"
            return str(nested) if nested.exists() else str(out_dir)

        # 2) Plain directory for the requested step.
        if step:
            raw = root / f"checkpoint_{step}"
            if raw.is_dir():
                if not any(raw.rglob(".zarray")):
                    raise RuntimeError(
                        f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
                        "Upload as a .tgz or push via git from a Unix shell."
                    )
                return str(raw)
            logging.warning(
                "checkpoint_%s not found in %s; falling back to the "
                "highest-numbered checkpoint directory",
                step, repo_id,
            )

        # 3) Fall back to the newest checkpoint directory in the snapshot.
        step_dirs = [
            d for d in root.iterdir()
            if d.is_dir() and re.match(r"checkpoint_\d+$", d.name)
        ]
        if step_dirs:
            pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
            if not any(pick.rglob(".zarray")):
                raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
            return str(pick)

        return None
|
|
|
|
|
|
|
class AssetManager:
    """Handles finetune asset loading and management.

    Loaded assets are kept on the instance: ``mean_embed`` (1-D array or
    None), ``centroids`` (2-D array or None) and ``assets_repo_id`` (the repo
    they were loaded from). The calling code reads them from there.
    """

    def __init__(self):
        self.mean_embed = None
        self.centroids = None
        self.assets_repo_id = None

    @staticmethod
    def _try_download(repo_id: str, filename: str) -> Optional[str]:
        """Download one asset file; return its local path or None on failure.

        A missing file is an expected situation (either asset may be absent),
        so failures are logged rather than raised.
        """
        try:
            return hf_hub_download(repo_id, filename=filename, repo_type="model")
        except Exception as exc:
            logging.info("Could not fetch %s from %s: %s", filename, repo_id, exc)
            return None

    @staticmethod
    def _embedding_dim(mrt) -> Optional[int]:
        """Best-effort read of ``mrt.style_model.config.embedding_dim``.

        Returns None when ``mrt`` is None or the attribute chain cannot be
        read or converted to an int.
        """
        if mrt is None:
            return None
        try:
            return int(mrt.style_model.config.embedding_dim)
        except Exception:
            return None

    @staticmethod
    def _dim_mismatch(mean, centroids, dim: Optional[int]) -> Optional[str]:
        """Return an error message if an asset's width differs from ``dim``.

        Returns None when ``dim`` is unknown or every present asset matches.
        """
        if dim is None:
            return None
        if mean is not None and mean.shape[0] != dim:
            return f"mean_style_embed dim {mean.shape[0]} != model dim {dim}"
        if centroids is not None and centroids.shape[1] != dim:
            return f"cluster_centroids dim {centroids.shape[1]} != model dim {dim}"
        return None

    def load_finetune_assets_from_hf(self, repo_id: Optional[str], mrt=None) -> Tuple[bool, str]:
        """Download and load mean_style_embed.npy / cluster_centroids.npy.

        Safe to call multiple times. Instance attributes are overwritten only
        when loading and validation succeed.

        Args:
            repo_id: HuggingFace repo ID, defaults to _FINETUNE_REPO_DEFAULT.
            mrt: Optional MagentaRT instance used to validate the embedding
                dimension. Validation is skipped if the dimension cannot be
                read from it.

        Returns:
            Tuple of (success: bool, message: str). The message is "ok" on
            success and a description of the problem otherwise.
        """
        repo_id = repo_id or _FINETUNE_REPO_DEFAULT
        try:
            mean_path = self._try_download(repo_id, "mean_style_embed.npy")
            cent_path = self._try_download(repo_id, "cluster_centroids.npy")
            if mean_path is None and cent_path is None:
                return False, f"No finetune asset files found in repo {repo_id}"

            mean = None
            if mean_path is not None:
                mean = np.load(mean_path)
                if mean.ndim != 1:
                    return False, f"mean_style_embed.npy must be 1-D (got {mean.shape})"

            cents = None
            if cent_path is not None:
                cents = np.load(cent_path)
                if cents.ndim != 2:
                    return False, f"cluster_centroids.npy must be 2-D (got {cents.shape})"

            problem = self._dim_mismatch(mean, cents, self._embedding_dim(mrt))
            if problem is not None:
                return False, problem

            # Commit only after every check has passed.
            self.mean_embed = mean.astype(np.float32, copy=False) if mean is not None else None
            self.centroids = cents.astype(np.float32, copy=False) if cents is not None else None
            self.assets_repo_id = repo_id

            logging.info(
                "Loaded finetune assets from %s (mean=%s, centroids=%s)",
                repo_id,
                "yes" if self.mean_embed is not None else "no",
                f"{self.centroids.shape[0]}x{self.centroids.shape[1]}"
                if self.centroids is not None else "no",
            )
            return True, "ok"
        except Exception as e:
            logging.exception("Failed to load finetune assets: %s", e)
            return False, str(e)

    def ensure_assets_loaded(self, mrt=None):
        """Best-effort lazy load if nothing is loaded yet."""
        if self.mean_embed is None and self.centroids is None:
            self.load_finetune_assets_from_hf(self.assets_repo_id or _FINETUNE_REPO_DEFAULT, mrt)

    def get_status(self, mrt=None) -> dict:
        """Return a summary of the currently loaded assets.

        Args:
            mrt: Optional MagentaRT instance; when given, its embedding
                dimension is reported (None if it cannot be read).
        """
        return {
            "repo_id": self.assets_repo_id,
            "mean_loaded": self.mean_embed is not None,
            "centroids_loaded": self.centroids is not None,
            "centroid_count": None if self.centroids is None else int(self.centroids.shape[0]),
            "embedding_dim": self._embedding_dim(mrt),
        }
|
|
|
|
|
|
|
class ModelSelector:
    """Handles model selection and validation logic.

    Encapsulates the logic of the /model/select endpoint. This class only
    reads environment variables and computes the changes to make; applying
    them is left to the calling code.
    """

    def __init__(self, checkpoint_manager: "CheckpointManager", asset_manager: "AssetManager"):
        self.checkpoint_manager = checkpoint_manager
        self.asset_manager = asset_manager

    @staticmethod
    def _step_is(step, keyword: str) -> bool:
        """Return True when ``step`` is the string ``keyword`` (case/space-insensitive)."""
        return isinstance(step, str) and step.strip().lower() == keyword

    @staticmethod
    def _parse_step(raw) -> Optional[int]:
        """Convert ``raw`` to an int step; return None if it is not numeric."""
        try:
            return int(raw)
        except (TypeError, ValueError):
            return None

    def validate_selection(self, req: "ModelSelect") -> Tuple[bool, dict]:
        """Validate a model selection request without making any changes.

        Fields missing from ``req`` fall back to the current ``MRT_*``
        environment values. ``req.step`` may be an integer, ``"latest"``
        (pick the highest discovered step) or ``"none"`` (no checkpoint).

        Returns:
            Tuple of (success: bool, result_dict: dict). On failure the dict
            holds an ``"error"`` message (plus ``"discovered_steps"`` when the
            repo was listed). On success it holds the ``target_*`` values,
            ``assets_repo`` and ``assets_probe``. A failed assets probe is
            reported in ``assets_probe`` and does not fail the validation.
        """
        # Current configuration, used as the fallback for unspecified fields.
        cur = {
            "size": os.getenv("MRT_SIZE", "large"),
            "repo": os.getenv("MRT_CKPT_REPO"),
            "rev": os.getenv("MRT_CKPT_REV", "main"),
            # An empty value means "unset", as in resolve_checkpoint_dir().
            "step": os.getenv("MRT_CKPT_STEP") or None,
            "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
        }

        no_ckpt = self._step_is(req.step, "none")
        latest = self._step_is(req.step, "latest")

        if no_ckpt or latest:
            wanted_step = None
        elif req.step is not None:
            wanted_step = str(req.step)
        else:
            wanted_step = cur["step"]

        tgt = {
            "size": req.size or cur["size"],
            "repo": None if no_ckpt else (req.repo_id or cur["repo"]),
            "rev": req.revision if req.revision is not None else cur["rev"],
            "step": wanted_step,
            "assets": req.assets_repo_id or req.repo_id or cur["assets"],
        }

        # "none": no checkpoint, nothing further to verify.
        if no_ckpt:
            return True, {
                "target_size": tgt["size"],
                "target_repo": None,
                "target_revision": None,
                "target_step": None,
                "assets_repo": None,
                "assets_probe": {"ok": True, "message": "skipped"},
            }

        if not tgt["repo"]:
            return False, {"error": "repo_id is required for model selection."}

        try:
            steps = self.checkpoint_manager.list_ckpt_steps(tgt["repo"], tgt["rev"])
        except Exception as e:
            return False, {"error": f"Failed to list checkpoints: {e}"}

        if not steps:
            return False, {
                "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}",
                "discovered_steps": steps,
            }

        # Pick the step: explicit value if given, otherwise the newest one.
        if tgt["step"] is None:
            chosen_step = steps[-1]
        else:
            chosen_step = self._parse_step(tgt["step"])
            if chosen_step is None:
                return False, {
                    "error": (
                        f"Invalid checkpoint step {tgt['step']!r}: "
                        "expected an integer, 'latest', or 'none'."
                    ),
                    "discovered_steps": steps,
                }
        if chosen_step not in steps:
            return False, {
                "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}",
                "discovered_steps": steps,
            }

        # Optional probe: does the assets repo contain at least one asset file?
        assets_ok, assets_msg = True, "skipped"
        if req.sync_assets:
            try:
                api = HfApi()
                files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
                if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
                    assets_ok, assets_msg = False, f"No finetune asset files in {tgt['assets']}"
                else:
                    assets_msg = "found"
            except Exception as e:
                assets_ok, assets_msg = False, f"probe failed: {e}"

        return True, {
            "target_size": tgt["size"],
            "target_repo": tgt["repo"],
            "target_revision": tgt["rev"],
            "target_step": chosen_step,
            "assets_repo": tgt["assets"] if req.sync_assets else None,
            "assets_probe": {"ok": assets_ok, "message": assets_msg},
        }

    def prepare_env_changes(self, req: "ModelSelect", validation_result: dict) -> dict:
        """Prepare the environment variable changes for a model selection.

        Args:
            req: The model selection request.
            validation_result: Result dict from a successful validate_selection().

        Returns:
            Mapping of environment variable name to new value. A value of
            None is used for the checkpoint/assets variables when the
            request's step is "none".
        """
        if self._step_is(req.step, "none"):
            return {
                "MRT_SIZE": validation_result["target_size"],
                "MRT_CKPT_REPO": None,
                "MRT_CKPT_REV": None,
                "MRT_CKPT_STEP": None,
                "MRT_ASSETS_REPO": None,
            }

        env_changes = {
            "MRT_SIZE": validation_result["target_size"],
            "MRT_CKPT_REPO": validation_result["target_repo"],
            "MRT_CKPT_REV": validation_result["target_revision"],
            "MRT_CKPT_STEP": str(validation_result["target_step"]),
        }
        # The assets repo is only touched when the request asked to sync it.
        if req.sync_assets:
            env_changes["MRT_ASSETS_REPO"] = validation_result["assets_repo"]
        return env_changes