thecollabagepatch commited on
Commit
49b5fee
·
1 Parent(s): d373851

full model switching logic

Browse files
Files changed (1) hide show
  1. app.py +186 -2
app.py CHANGED
@@ -51,7 +51,7 @@ import uuid, threading
51
  import logging
52
 
53
  import gradio as gr
54
- from typing import Optional
55
 
56
 
57
  import json, asyncio, base64
@@ -68,7 +68,9 @@ except Exception:
68
 
69
  import re, tarfile
70
  from pathlib import Path
71
- from huggingface_hub import snapshot_download
 
 
72
 
73
  # ---- Finetune assets (mean & centroids) --------------------------------------
74
  _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
@@ -76,6 +78,43 @@ _ASSETS_REPO_ID: str | None = None
76
  _MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
77
  _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
80
  """
81
  Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
@@ -927,6 +966,151 @@ def model_assets_status():
927
  "embedding_dim": d,
928
  }
929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  @app.post("/generate")
931
  def generate(
932
  loop_audio: UploadFile = File(...),
 
51
  import logging
52
 
53
  import gradio as gr
54
+ from typing import Optional, Union, Literal
55
 
56
 
57
  import json, asyncio, base64
 
68
 
69
  import re, tarfile
70
  from pathlib import Path
71
+ from huggingface_hub import snapshot_download, HfApi
72
+
73
+ from pydantic import BaseModel
74
 
75
  # ---- Finetune assets (mean & centroids) --------------------------------------
76
  _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
 
78
  _MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
79
  _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
80
 
81
+ _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
82
+
83
+ def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
84
+ """
85
+ List available checkpoint steps in a HF model repo without downloading all weights.
86
+ Looks for:
87
+ checkpoint_<step>/
88
+ checkpoint_<step>.tgz | .tar.gz
89
+ archives/checkpoint_<step>.tgz | .tar.gz
90
+ """
91
+ api = HfApi()
92
+ files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
93
+ steps = set()
94
+ for f in files:
95
+ m = _STEP_RE.search(f)
96
+ if m:
97
+ try:
98
+ steps.add(int(m.group(1)))
99
+ except:
100
+ pass
101
+ return sorted(steps)
102
+
103
+ def _step_exists(repo_id: str, revision: str, step: int) -> bool:
104
+ return step in _list_ckpt_steps(repo_id, revision)
105
+
106
+ def _any_jam_running() -> bool:
107
+ with jam_lock:
108
+ return any(w.is_alive() for w in jam_registry.values())
109
+
110
+ def _stop_all_jams(timeout: float = 5.0):
111
+ with jam_lock:
112
+ for sid, w in list(jam_registry.items()):
113
+ if w.is_alive():
114
+ w.stop()
115
+ w.join(timeout=timeout)
116
+ jam_registry.pop(sid, None)
117
+
118
  def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
119
  """
120
  Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
 
966
  "embedding_dim": d,
967
  }
968
 
969
+ @app.get("/model/config")
970
+ def model_config():
971
+ mrt = None
972
+ try:
973
+ mrt = get_mrt()
974
+ except Exception:
975
+ pass
976
+ return {
977
+ "size": os.getenv("MRT_SIZE", "large"),
978
+ "repo": os.getenv("MRT_CKPT_REPO"),
979
+ "revision": os.getenv("MRT_CKPT_REV", "main"),
980
+ "selected_step": os.getenv("MRT_CKPT_STEP"),
981
+ "resolved_ckpt_dir": _resolve_checkpoint_dir(), # may be None if not yet downloaded
982
+ "loaded": bool(mrt),
983
+ }
984
+
985
+ @app.get("/model/checkpoints")
986
+ def model_checkpoints(repo_id: str, revision: str = "main"):
987
+ steps = _list_ckpt_steps(repo_id, revision)
988
+ return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
989
+
990
+ class ModelSelect(BaseModel):
991
+ size: Optional[Literal["base","large"]] = None
992
+ repo_id: Optional[str] = None
993
+ revision: Optional[str] = "main"
994
+ step: Optional[Union[int, str]] = None # allow "latest"
995
+ assets_repo_id: Optional[str] = None # default: follow repo_id
996
+ sync_assets: bool = True # load mean/centroids from repo
997
+ prewarm: bool = False # call get_mrt() to build right away
998
+ stop_active: bool = True # auto-stop jams; else 409
999
+ dry_run: bool = False # validate only, don't swap
1000
+
1001
+ @app.post("/model/select")
1002
+ def model_select(req: ModelSelect):
1003
+ # Resolve desired target config (fall back to current env)
1004
+ cur = {
1005
+ "size": os.getenv("MRT_SIZE", "large"),
1006
+ "repo": os.getenv("MRT_CKPT_REPO"),
1007
+ "rev": os.getenv("MRT_CKPT_REV", "main"),
1008
+ "step": os.getenv("MRT_CKPT_STEP"),
1009
+ "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
1010
+ }
1011
+ tgt = {
1012
+ "size": req.size or cur["size"],
1013
+ "repo": req.repo_id or cur["repo"],
1014
+ "rev": (req.revision if req.revision is not None else cur["rev"]),
1015
+ "step": (None if (isinstance(req.step, str) and req.step.lower()=="latest") else (str(req.step) if req.step is not None else cur["step"])),
1016
+ "assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
1017
+ }
1018
+
1019
+ if not tgt["repo"]:
1020
+ raise HTTPException(status_code=400, detail="repo_id is required at least once before selecting 'latest'.")
1021
+
1022
+ # ---- Dry-run validation (no env changes) ----
1023
+ # 1) enumerate steps
1024
+ steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
1025
+ if not steps:
1026
+ return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1027
+
1028
+ # 2) choose step
1029
+ chosen_step = int(tgt["step"]) if tgt["step"] is not None else steps[-1]
1030
+ if chosen_step not in steps:
1031
+ return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1032
+
1033
+ # 3) optional: quick asset sanity (only list, don’t download)
1034
+ assets_ok = True
1035
+ assets_msg = "skipped"
1036
+ if req.sync_assets:
1037
+ try:
1038
+ # a quick probe: ensure either file exists; if not, allow anyway (assets are optional)
1039
+ api = HfApi()
1040
+ files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
1041
+ if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
1042
+ assets_ok, assets_msg = False, f"No finetune asset files in {tgt['assets']}"
1043
+ else:
1044
+ assets_msg = "found"
1045
+ except Exception as e:
1046
+ assets_ok, assets_msg = False, f"probe failed: {e}"
1047
+
1048
+ preview = {
1049
+ "target_size": tgt["size"],
1050
+ "target_repo": tgt["repo"],
1051
+ "target_revision": tgt["rev"],
1052
+ "target_step": chosen_step,
1053
+ "assets_repo": tgt["assets"] if req.sync_assets else None,
1054
+ "assets_probe": {"ok": assets_ok, "message": assets_msg},
1055
+ "active_jam": _any_jam_running(),
1056
+ }
1057
+ if req.dry_run:
1058
+ return {"ok": True, "dry_run": True, **preview}
1059
+
1060
+ # ---- Enforce jam policy ----
1061
+ if _any_jam_running():
1062
+ if req.stop_active:
1063
+ _stop_all_jams()
1064
+ else:
1065
+ raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
1066
+
1067
+ # ---- Atomic swap with rollback ----
1068
+ old_env = {
1069
+ "MRT_SIZE": os.getenv("MRT_SIZE"),
1070
+ "MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
1071
+ "MRT_CKPT_REV": os.getenv("MRT_CKPT_REV"),
1072
+ "MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
1073
+ "MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
1074
+ }
1075
+ try:
1076
+ os.environ["MRT_SIZE"] = str(tgt["size"])
1077
+ os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
1078
+ os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
1079
+ os.environ["MRT_CKPT_STEP"] = str(chosen_step)
1080
+
1081
+ if req.sync_assets:
1082
+ os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])
1083
+
1084
+ # force rebuild
1085
+ global _MRT
1086
+ with _MRT_LOCK:
1087
+ _MRT = None
1088
+
1089
+ # optionally sync+load finetune assets
1090
+ if req.sync_assets:
1091
+ _load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
1092
+
1093
+ # optional pre-warm to amortize JIT
1094
+ if req.prewarm:
1095
+ get_mrt() # triggers snapshot_download/resolve + init
1096
+
1097
+ return {"ok": True, **preview}
1098
+ except Exception as e:
1099
+ # rollback on error
1100
+ for k, v in old_env.items():
1101
+ if v is None:
1102
+ os.environ.pop(k, None)
1103
+ else:
1104
+ os.environ[k] = v
1105
+ with _MRT_LOCK:
1106
+ _MRT = None
1107
+ try:
1108
+ get_mrt()
1109
+ except Exception:
1110
+ pass
1111
+ raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
1112
+
1113
+
1114
  @app.post("/generate")
1115
  def generate(
1116
  loop_audio: UploadFile = File(...),