Commit
·
49b5fee
1
Parent(s):
d373851
full model switching logic
Browse files
app.py
CHANGED
@@ -51,7 +51,7 @@ import uuid, threading
|
|
51 |
import logging
|
52 |
|
53 |
import gradio as gr
|
54 |
-
from typing import Optional
|
55 |
|
56 |
|
57 |
import json, asyncio, base64
|
@@ -68,7 +68,9 @@ except Exception:
|
|
68 |
|
69 |
import re, tarfile
|
70 |
from pathlib import Path
|
71 |
-
from huggingface_hub import snapshot_download
|
|
|
|
|
72 |
|
73 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
74 |
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
@@ -76,6 +78,43 @@ _ASSETS_REPO_ID: str | None = None
|
|
76 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
77 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
|
80 |
"""
|
81 |
Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
|
@@ -927,6 +966,151 @@ def model_assets_status():
|
|
927 |
"embedding_dim": d,
|
928 |
}
|
929 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
930 |
@app.post("/generate")
|
931 |
def generate(
|
932 |
loop_audio: UploadFile = File(...),
|
|
|
51 |
import logging
|
52 |
|
53 |
import gradio as gr
|
54 |
+
from typing import Optional, Union, Literal
|
55 |
|
56 |
|
57 |
import json, asyncio, base64
|
|
|
68 |
|
69 |
import re, tarfile
|
70 |
from pathlib import Path
|
71 |
+
from huggingface_hub import snapshot_download, HfApi
|
72 |
+
|
73 |
+
from pydantic import BaseModel
|
74 |
|
75 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
76 |
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
|
|
78 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
79 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
80 |
|
81 |
+
# Matches one "checkpoint_<step>" path component, either at the repo root or
# below any folder:
#   * a directory, i.e. followed by "/" and anything underneath it,
#   * an archive named checkpoint_<step>.tgz or checkpoint_<step>.tar.gz,
#   * a bare "checkpoint_<step>" entry.
# The end-of-string anchor applies only to the archive/bare forms so that files
# *inside* a checkpoint directory (which is what a file listing returns) match.
_STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz$|\.tgz$|$)")


def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
    """
    List available checkpoint steps in a HF model repo without downloading all weights.

    Looks for:
      checkpoint_<step>/...                      (any file below the directory)
      checkpoint_<step>.tgz | .tar.gz
      archives/checkpoint_<step>.tgz | .tar.gz

    Args:
        repo_id: Hugging Face model repo to inspect.
        revision: Branch, tag or commit to list (defaults to "main").

    Returns:
        The distinct step numbers found, sorted ascending (empty if none).

    Raises:
        Whatever ``HfApi.list_repo_files`` raises (e.g. unknown repo/revision,
        network failure); errors are not swallowed here.
    """
    repo_files = HfApi().list_repo_files(
        repo_id=repo_id, repo_type="model", revision=revision
    )
    matches = (_STEP_RE.search(path) for path in repo_files)
    # group(1) is guaranteed to be digits by the pattern, so int() cannot fail.
    return sorted({int(m.group(1)) for m in matches if m})
|
102 |
+
|
103 |
+
def _step_exists(repo_id: str, revision: str, step: int) -> bool:
    """Return True when ``step`` is among the checkpoint steps listed for ``repo_id`` at ``revision``."""
    available_steps = _list_ckpt_steps(repo_id, revision)
    return step in available_steps
|
105 |
+
|
106 |
+
def _any_jam_running() -> bool:
    """Report whether at least one worker in ``jam_registry`` is still alive."""
    # The registry is inspected under jam_lock, as everywhere else it is touched here.
    with jam_lock:
        for worker in jam_registry.values():
            if worker.is_alive():
                return True
        return False
|
109 |
+
|
110 |
+
def _stop_all_jams(timeout: float = 5.0):
    """
    Ask every live jam worker to stop, wait for them, and drop them from the registry.

    Args:
        timeout: Maximum seconds to wait for each worker's ``join``.

    All live workers are signalled first and the waiting happens *outside*
    ``jam_lock``: joining while holding the lock would block every other
    registry user for up to ``timeout`` per worker, and would stall outright
    if a worker needs the lock to finish (NOTE(review): whether workers take
    ``jam_lock`` on shutdown is not visible here).
    """
    # Snapshot the live workers and signal them all up front so they wind down
    # concurrently instead of one after another.
    with jam_lock:
        live = [(sid, w) for sid, w in jam_registry.items() if w.is_alive()]
        for _, worker in live:
            worker.stop()

    for _, worker in live:
        worker.join(timeout=timeout)

    # Remove only the entries we actually stopped; a session id that was
    # re-registered with a new worker in the meantime is left alone.
    with jam_lock:
        for sid, worker in live:
            if jam_registry.get(sid) is worker:
                jam_registry.pop(sid, None)
|
117 |
+
|
118 |
def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
|
119 |
"""
|
120 |
Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
|
|
|
966 |
"embedding_dim": d,
|
967 |
}
|
968 |
|
969 |
+
@app.get("/model/config")
def model_config():
    """
    Report the current model selection (from the MRT_* environment variables)
    and whether a model instance could be obtained.

    Returns:
        dict with keys ``size``, ``repo``, ``revision``, ``selected_step``,
        ``resolved_ckpt_dir`` and ``loaded``.
    """
    mrt = None
    try:
        # NOTE(review): elsewhere in this file get_mrt() is described as
        # building the model on demand, so this status call may itself trigger
        # a load -- confirm that is intended for a read-only endpoint.
        mrt = get_mrt()
    except Exception:
        # Still answer with the configuration, but leave a trace of why the
        # model is reported as not loaded instead of hiding the failure.
        logging.exception("model_config: get_mrt() failed; reporting loaded=False")
    return {
        "size": os.getenv("MRT_SIZE", "large"),
        "repo": os.getenv("MRT_CKPT_REPO"),
        "revision": os.getenv("MRT_CKPT_REV", "main"),
        "selected_step": os.getenv("MRT_CKPT_STEP"),
        "resolved_ckpt_dir": _resolve_checkpoint_dir(),  # may be None if not yet downloaded
        # Explicit None check: the answer must not depend on the model object's truthiness.
        "loaded": mrt is not None,
    }
|
984 |
+
|
985 |
+
# Lists the checkpoint steps discovered in `repo_id` at `revision`, plus the
# highest one as "latest" (None when the repo has no checkpoints).
@app.get("/model/checkpoints")
def model_checkpoints(repo_id: str, revision: str = "main"):
    found = _list_ckpt_steps(repo_id, revision)
    newest = found[-1] if found else None
    return {
        "repo": repo_id,
        "revision": revision,
        "steps": found,
        "latest": newest,
    }
|
989 |
+
|
990 |
+
# Request body for POST /model/select. Every field is optional; anything left
# unset falls back to the currently configured value in the handler.
class ModelSelect(BaseModel):
    # Which model to point at.
    size: Optional[Literal["base","large"]] = None
    repo_id: Optional[str] = None
    revision: Optional[str] = "main"
    # An integer step, or the string "latest".
    step: Optional[Union[int, str]] = None

    # Finetune assets (mean / centroids).
    # When assets_repo_id is unset the handler follows repo_id.
    assets_repo_id: Optional[str] = None
    # Load mean/centroids from the assets repo as part of the switch.
    sync_assets: bool = True

    # How the switch is carried out.
    # Call get_mrt() immediately so the model is built right away.
    prewarm: bool = False
    # Stop running jams automatically; when False a running jam yields 409.
    stop_active: bool = True
    # Validate only; do not swap anything.
    dry_run: bool = False
|
1000 |
+
|
1001 |
+
def _parse_ckpt_step(raw) -> int | None:
    """
    Normalize a checkpoint step selector.

    Returns None to mean "use the latest step" (for None, a blank string, or
    "latest" in any letter case); otherwise the step as an int.

    Raises:
        ValueError: if ``raw`` is a string that is neither "latest" nor an integer.
    """
    if raw is None:
        return None
    if isinstance(raw, int):
        return raw
    text = str(raw).strip()
    if not text or text.lower() == "latest":
        return None
    return int(text)


@app.post("/model/select")
def model_select(req: ModelSelect):
    """
    Validate a requested model/checkpoint selection and, unless ``dry_run`` is
    set, switch to it by rewriting the MRT_* environment variables and clearing
    the cached model so it is rebuilt.

    Returns:
        ``{"ok": False, "error": ..., "discovered_steps": [...]}`` when the repo
        has no checkpoints or the chosen step is missing; otherwise
        ``{"ok": True, ...preview}`` (plus ``"dry_run": True`` for dry runs and
        ``"assets_sync"`` when assets were loaded).

    Raises:
        HTTPException 400: no repo is known, or ``step`` is not an integer/"latest".
        HTTPException 409: a jam is running and ``stop_active`` is false.
        HTTPException 500: the swap failed; the previous environment is restored first.
    """
    global _MRT

    # Resolve desired target config (fall back to current env)
    cur = {
        "size": os.getenv("MRT_SIZE", "large"),
        "repo": os.getenv("MRT_CKPT_REPO"),
        "rev": os.getenv("MRT_CKPT_REV", "main"),
        "step": os.getenv("MRT_CKPT_STEP"),
        "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
    }
    tgt = {
        "size": req.size or cur["size"],
        "repo": req.repo_id or cur["repo"],
        "rev": (req.revision if req.revision is not None else cur["rev"]),
        "assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
    }

    if not tgt["repo"]:
        raise HTTPException(status_code=400, detail="repo_id is required at least once before selecting 'latest'.")

    # An explicit request value wins; otherwise fall back to the configured step.
    # A malformed value is a client/config error, not a server crash.
    raw_step = req.step if req.step is not None else cur["step"]
    try:
        wanted_step = _parse_ckpt_step(raw_step)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid checkpoint step {raw_step!r}; expected an integer or 'latest'.",
        ) from None

    # ---- Dry-run validation (no env changes) ----
    # 1) enumerate steps
    steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
    if not steps:
        return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}

    # 2) choose step (steps is sorted ascending, so the last one is the latest)
    chosen_step = wanted_step if wanted_step is not None else steps[-1]
    if chosen_step not in steps:
        return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}

    # 3) optional: quick asset sanity (only list, don't download)
    assets_ok = True
    assets_msg = "skipped"
    if req.sync_assets:
        try:
            # A quick probe: flag the repo only when *neither* asset file is
            # listed; this never blocks the switch (assets are optional).
            files = set(HfApi().list_repo_files(repo_id=tgt["assets"], repo_type="model"))
            if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
                assets_ok, assets_msg = False, f"No finetune asset files in {tgt['assets']}"
            else:
                assets_msg = "found"
        except Exception as e:
            assets_ok, assets_msg = False, f"probe failed: {e}"

    preview = {
        "target_size": tgt["size"],
        "target_repo": tgt["repo"],
        "target_revision": tgt["rev"],
        "target_step": chosen_step,
        "assets_repo": tgt["assets"] if req.sync_assets else None,
        "assets_probe": {"ok": assets_ok, "message": assets_msg},
        "active_jam": _any_jam_running(),
    }
    if req.dry_run:
        return {"ok": True, "dry_run": True, **preview}

    # ---- Enforce jam policy ----
    if _any_jam_running():
        if req.stop_active:
            _stop_all_jams()
        else:
            raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")

    # ---- Swap with rollback ----
    # Remember the exact previous values (None = variable was unset) so a
    # failure can restore the environment precisely.
    old_env = {
        key: os.getenv(key)
        for key in ("MRT_SIZE", "MRT_CKPT_REPO", "MRT_CKPT_REV", "MRT_CKPT_STEP", "MRT_ASSETS_REPO")
    }
    try:
        os.environ["MRT_SIZE"] = str(tgt["size"])
        os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
        os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
        os.environ["MRT_CKPT_STEP"] = str(chosen_step)

        if req.sync_assets:
            os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])

        # force rebuild
        with _MRT_LOCK:
            _MRT = None

        result = {"ok": True, **preview}

        # optionally sync+load finetune assets, and report the outcome rather
        # than dropping the (ok, message) result on the floor
        if req.sync_assets:
            synced, sync_msg = _load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
            result["assets_sync"] = {"ok": synced, "message": sync_msg}

        # optional pre-warm to amortize JIT
        if req.prewarm:
            get_mrt()  # triggers snapshot_download/resolve + init

        return result
    except Exception as e:
        # rollback on error
        for k, v in old_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
        with _MRT_LOCK:
            _MRT = None
        # NOTE(review): finetune assets loaded above are not reverted here --
        # confirm whether they should be reloaded for the previous repo.
        try:
            get_mrt()
        except Exception:
            # Best effort only: the caller still gets the original failure below.
            logging.exception("model_select: rebuilding the previous model after a failed swap also failed")
        raise HTTPException(status_code=500, detail=f"Swap failed: {e}") from e
|
1112 |
+
|
1113 |
+
|
1114 |
@app.post("/generate")
|
1115 |
def generate(
|
1116 |
loop_audio: UploadFile = File(...),
|