thecollabagepatch committed on
Commit
aba0837
·
1 Parent(s): 49b5fee

refactored model_select

Browse files
Files changed (1) hide show
  1. app.py +66 -25
app.py CHANGED
@@ -1000,7 +1000,7 @@ class ModelSelect(BaseModel):
1000
 
1001
  @app.post("/model/select")
1002
  def model_select(req: ModelSelect):
1003
- # Resolve desired target config (fall back to current env)
1004
  cur = {
1005
  "size": os.getenv("MRT_SIZE", "large"),
1006
  "repo": os.getenv("MRT_CKPT_REPO"),
@@ -1008,34 +1008,74 @@ def model_select(req: ModelSelect):
1008
  "step": os.getenv("MRT_CKPT_STEP"),
1009
  "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
1010
  }
 
 
 
 
 
 
1011
  tgt = {
1012
- "size": req.size or cur["size"],
1013
- "repo": req.repo_id or cur["repo"],
1014
  "rev": (req.revision if req.revision is not None else cur["rev"]),
1015
- "step": (None if (isinstance(req.step, str) and req.step.lower()=="latest") else (str(req.step) if req.step is not None else cur["step"])),
 
1016
  "assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
1017
  }
1018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1019
  if not tgt["repo"]:
1020
- raise HTTPException(status_code=400, detail="repo_id is required at least once before selecting 'latest'.")
1021
 
1022
- # ---- Dry-run validation (no env changes) ----
1023
- # 1) enumerate steps
1024
  steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
1025
  if not steps:
1026
  return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1027
 
1028
- # 2) choose step
1029
  chosen_step = int(tgt["step"]) if tgt["step"] is not None else steps[-1]
1030
  if chosen_step not in steps:
1031
  return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1032
 
1033
- # 3) optional: quick asset sanity (only list, don’t download)
1034
- assets_ok = True
1035
- assets_msg = "skipped"
1036
  if req.sync_assets:
1037
  try:
1038
- # a quick probe: ensure either file exists; if not, allow anyway (assets are optional)
1039
  api = HfApi()
1040
  files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
1041
  if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
@@ -1050,34 +1090,34 @@ def model_select(req: ModelSelect):
1050
  "target_repo": tgt["repo"],
1051
  "target_revision": tgt["rev"],
1052
  "target_step": chosen_step,
1053
- "assets_repo": tgt["assets"] if req.sync_assets else None,
1054
  "assets_probe": {"ok": assets_ok, "message": assets_msg},
1055
  "active_jam": _any_jam_running(),
1056
  }
 
1057
  if req.dry_run:
1058
  return {"ok": True, "dry_run": True, **preview}
1059
 
1060
- # ---- Enforce jam policy ----
1061
  if _any_jam_running():
1062
  if req.stop_active:
1063
  _stop_all_jams()
1064
  else:
1065
  raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
1066
 
1067
- # ---- Atomic swap with rollback ----
1068
  old_env = {
1069
- "MRT_SIZE": os.getenv("MRT_SIZE"),
1070
- "MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
1071
- "MRT_CKPT_REV": os.getenv("MRT_CKPT_REV"),
1072
- "MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
1073
- "MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
1074
  }
1075
  try:
1076
- os.environ["MRT_SIZE"] = str(tgt["size"])
1077
  os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
1078
  os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
1079
  os.environ["MRT_CKPT_STEP"] = str(chosen_step)
1080
-
1081
  if req.sync_assets:
1082
  os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])
1083
 
@@ -1086,13 +1126,13 @@ def model_select(req: ModelSelect):
1086
  with _MRT_LOCK:
1087
  _MRT = None
1088
 
1089
- # optionally sync+load finetune assets
1090
  if req.sync_assets:
1091
  _load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
1092
 
1093
- # optional pre-warm to amortize JIT
1094
  if req.prewarm:
1095
- get_mrt() # triggers snapshot_download/resolve + init
1096
 
1097
  return {"ok": True, **preview}
1098
  except Exception as e:
@@ -1111,6 +1151,7 @@ def model_select(req: ModelSelect):
1111
  raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
1112
 
1113
 
 
1114
  @app.post("/generate")
1115
  def generate(
1116
  loop_audio: UploadFile = File(...),
 
1000
 
1001
  @app.post("/model/select")
1002
  def model_select(req: ModelSelect):
1003
+ # --- Current env defaults ---
1004
  cur = {
1005
  "size": os.getenv("MRT_SIZE", "large"),
1006
  "repo": os.getenv("MRT_CKPT_REPO"),
 
1008
  "step": os.getenv("MRT_CKPT_STEP"),
1009
  "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
1010
  }
1011
+
1012
+ # --- Flags for special step values ---
1013
+ no_ckpt = isinstance(req.step, str) and req.step.lower() == "none"
1014
+ latest = isinstance(req.step, str) and req.step.lower() == "latest"
1015
+
1016
+ # --- Target selection (do not require repo when no_ckpt) ---
1017
  tgt = {
1018
+ "size": (req.size or cur["size"]),
1019
+ "repo": (None if no_ckpt else (req.repo_id or cur["repo"])),
1020
  "rev": (req.revision if req.revision is not None else cur["rev"]),
1021
+ # None => resolve to "latest" below. Keep None for no_ckpt as well.
1022
+ "step": (None if (no_ckpt or latest) else (str(req.step) if req.step is not None else cur["step"])),
1023
  "assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
1024
  }
1025
 
1026
+ # ---------- CASE 1: run with NO FINETUNE (stock base/large) ----------
1027
+ if no_ckpt:
1028
+ preview = {
1029
+ "target_size": tgt["size"],
1030
+ "target_repo": None,
1031
+ "target_revision": None,
1032
+ "target_step": None,
1033
+ "assets_repo": None,
1034
+ "assets_probe": {"ok": True, "message": "skipped"},
1035
+ "active_jam": _any_jam_running(),
1036
+ }
1037
+ if req.dry_run:
1038
+ return {"ok": True, "dry_run": True, **preview}
1039
+
1040
+ # Jam policy
1041
+ if _any_jam_running():
1042
+ if req.stop_active:
1043
+ _stop_all_jams()
1044
+ else:
1045
+ raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
1046
+
1047
+ # Clear checkpoint + asset env so get_mrt() uses stock weights
1048
+ for k in ("MRT_CKPT_REPO", "MRT_CKPT_REV", "MRT_CKPT_STEP", "MRT_ASSETS_REPO"):
1049
+ os.environ.pop(k, None)
1050
+ os.environ["MRT_SIZE"] = str(tgt["size"])
1051
+
1052
+ # Rebuild model and optionally prewarm
1053
+ global _MRT
1054
+ with _MRT_LOCK:
1055
+ _MRT = None
1056
+ if req.prewarm:
1057
+ get_mrt()
1058
+
1059
+ return {"ok": True, **preview}
1060
+
1061
+ # ---------- CASE 2: select a repo + step (supports "latest") ----------
1062
  if not tgt["repo"]:
1063
+ raise HTTPException(status_code=400, detail="repo_id is required for model selection.")
1064
 
1065
+ # 1) enumerate available steps
 
1066
  steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
1067
  if not steps:
1068
  return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1069
 
1070
+ # 2) choose step (explicit or latest)
1071
  chosen_step = int(tgt["step"]) if tgt["step"] is not None else steps[-1]
1072
  if chosen_step not in steps:
1073
  return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
1074
 
1075
+ # 3) optional finetune assets probe (no downloads, just listing)
1076
+ assets_ok, assets_msg = True, "skipped"
 
1077
  if req.sync_assets:
1078
  try:
 
1079
  api = HfApi()
1080
  files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
1081
  if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
 
1090
  "target_repo": tgt["repo"],
1091
  "target_revision": tgt["rev"],
1092
  "target_step": chosen_step,
1093
+ "assets_repo": (tgt["assets"] if req.sync_assets else None),
1094
  "assets_probe": {"ok": assets_ok, "message": assets_msg},
1095
  "active_jam": _any_jam_running(),
1096
  }
1097
+
1098
  if req.dry_run:
1099
  return {"ok": True, "dry_run": True, **preview}
1100
 
1101
+ # Jam policy
1102
  if _any_jam_running():
1103
  if req.stop_active:
1104
  _stop_all_jams()
1105
  else:
1106
  raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
1107
 
1108
+ # 4) atomic swap with rollback
1109
  old_env = {
1110
+ "MRT_SIZE": os.getenv("MRT_SIZE"),
1111
+ "MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
1112
+ "MRT_CKPT_REV": os.getenv("MRT_CKPT_REV"),
1113
+ "MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
1114
+ "MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
1115
  }
1116
  try:
1117
+ os.environ["MRT_SIZE"] = str(tgt["size"])
1118
  os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
1119
  os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
1120
  os.environ["MRT_CKPT_STEP"] = str(chosen_step)
 
1121
  if req.sync_assets:
1122
  os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])
1123
 
 
1126
  with _MRT_LOCK:
1127
  _MRT = None
1128
 
1129
+ # optionally load finetune assets now
1130
  if req.sync_assets:
1131
  _load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
1132
 
1133
+ # optional prewarm to amortize JIT
1134
  if req.prewarm:
1135
+ get_mrt()
1136
 
1137
  return {"ok": True, **preview}
1138
  except Exception as e:
 
1151
  raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
1152
 
1153
 
1154
+
1155
  @app.post("/generate")
1156
  def generate(
1157
  loop_audio: UploadFile = File(...),