Commit
·
aba0837
1
Parent(s):
49b5fee
refactored model_select
Browse files
app.py
CHANGED
@@ -1000,7 +1000,7 @@ class ModelSelect(BaseModel):
|
|
1000 |
|
1001 |
@app.post("/model/select")
|
1002 |
def model_select(req: ModelSelect):
|
1003 |
-
#
|
1004 |
cur = {
|
1005 |
"size": os.getenv("MRT_SIZE", "large"),
|
1006 |
"repo": os.getenv("MRT_CKPT_REPO"),
|
@@ -1008,34 +1008,74 @@ def model_select(req: ModelSelect):
|
|
1008 |
"step": os.getenv("MRT_CKPT_STEP"),
|
1009 |
"assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
|
1010 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
1011 |
tgt = {
|
1012 |
-
"size": req.size or cur["size"],
|
1013 |
-
"repo": req.repo_id or cur["repo"],
|
1014 |
"rev": (req.revision if req.revision is not None else cur["rev"]),
|
1015 |
-
|
|
|
1016 |
"assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
|
1017 |
}
|
1018 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1019 |
if not tgt["repo"]:
|
1020 |
-
raise HTTPException(status_code=400, detail="repo_id is required
|
1021 |
|
1022 |
-
#
|
1023 |
-
# 1) enumerate steps
|
1024 |
steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
|
1025 |
if not steps:
|
1026 |
return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
|
1027 |
|
1028 |
-
# 2) choose step
|
1029 |
chosen_step = int(tgt["step"]) if tgt["step"] is not None else steps[-1]
|
1030 |
if chosen_step not in steps:
|
1031 |
return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
|
1032 |
|
1033 |
-
# 3) optional
|
1034 |
-
assets_ok = True
|
1035 |
-
assets_msg = "skipped"
|
1036 |
if req.sync_assets:
|
1037 |
try:
|
1038 |
-
# a quick probe: ensure either file exists; if not, allow anyway (assets are optional)
|
1039 |
api = HfApi()
|
1040 |
files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
|
1041 |
if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
|
@@ -1050,34 +1090,34 @@ def model_select(req: ModelSelect):
|
|
1050 |
"target_repo": tgt["repo"],
|
1051 |
"target_revision": tgt["rev"],
|
1052 |
"target_step": chosen_step,
|
1053 |
-
"assets_repo": tgt["assets"] if req.sync_assets else None,
|
1054 |
"assets_probe": {"ok": assets_ok, "message": assets_msg},
|
1055 |
"active_jam": _any_jam_running(),
|
1056 |
}
|
|
|
1057 |
if req.dry_run:
|
1058 |
return {"ok": True, "dry_run": True, **preview}
|
1059 |
|
1060 |
-
#
|
1061 |
if _any_jam_running():
|
1062 |
if req.stop_active:
|
1063 |
_stop_all_jams()
|
1064 |
else:
|
1065 |
raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
|
1066 |
|
1067 |
-
#
|
1068 |
old_env = {
|
1069 |
-
"MRT_SIZE":
|
1070 |
-
"MRT_CKPT_REPO":
|
1071 |
-
"MRT_CKPT_REV":
|
1072 |
-
"MRT_CKPT_STEP":
|
1073 |
-
"MRT_ASSETS_REPO":
|
1074 |
}
|
1075 |
try:
|
1076 |
-
os.environ["MRT_SIZE"]
|
1077 |
os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
|
1078 |
os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
|
1079 |
os.environ["MRT_CKPT_STEP"] = str(chosen_step)
|
1080 |
-
|
1081 |
if req.sync_assets:
|
1082 |
os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])
|
1083 |
|
@@ -1086,13 +1126,13 @@ def model_select(req: ModelSelect):
|
|
1086 |
with _MRT_LOCK:
|
1087 |
_MRT = None
|
1088 |
|
1089 |
-
# optionally
|
1090 |
if req.sync_assets:
|
1091 |
_load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
|
1092 |
|
1093 |
-
# optional
|
1094 |
if req.prewarm:
|
1095 |
-
get_mrt()
|
1096 |
|
1097 |
return {"ok": True, **preview}
|
1098 |
except Exception as e:
|
@@ -1111,6 +1151,7 @@ def model_select(req: ModelSelect):
|
|
1111 |
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
1112 |
|
1113 |
|
|
|
1114 |
@app.post("/generate")
|
1115 |
def generate(
|
1116 |
loop_audio: UploadFile = File(...),
|
|
|
1000 |
|
1001 |
@app.post("/model/select")
|
1002 |
def model_select(req: ModelSelect):
|
1003 |
+
# --- Current env defaults ---
|
1004 |
cur = {
|
1005 |
"size": os.getenv("MRT_SIZE", "large"),
|
1006 |
"repo": os.getenv("MRT_CKPT_REPO"),
|
|
|
1008 |
"step": os.getenv("MRT_CKPT_STEP"),
|
1009 |
"assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
|
1010 |
}
|
1011 |
+
|
1012 |
+
# --- Flags for special step values ---
|
1013 |
+
no_ckpt = isinstance(req.step, str) and req.step.lower() == "none"
|
1014 |
+
latest = isinstance(req.step, str) and req.step.lower() == "latest"
|
1015 |
+
|
1016 |
+
# --- Target selection (do not require repo when no_ckpt) ---
|
1017 |
tgt = {
|
1018 |
+
"size": (req.size or cur["size"]),
|
1019 |
+
"repo": (None if no_ckpt else (req.repo_id or cur["repo"])),
|
1020 |
"rev": (req.revision if req.revision is not None else cur["rev"]),
|
1021 |
+
# None => resolve to "latest" below. Keep None for no_ckpt as well.
|
1022 |
+
"step": (None if (no_ckpt or latest) else (str(req.step) if req.step is not None else cur["step"])),
|
1023 |
"assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
|
1024 |
}
|
1025 |
|
1026 |
+
# ---------- CASE 1: run with NO FINETUNE (stock base/large) ----------
|
1027 |
+
if no_ckpt:
|
1028 |
+
preview = {
|
1029 |
+
"target_size": tgt["size"],
|
1030 |
+
"target_repo": None,
|
1031 |
+
"target_revision": None,
|
1032 |
+
"target_step": None,
|
1033 |
+
"assets_repo": None,
|
1034 |
+
"assets_probe": {"ok": True, "message": "skipped"},
|
1035 |
+
"active_jam": _any_jam_running(),
|
1036 |
+
}
|
1037 |
+
if req.dry_run:
|
1038 |
+
return {"ok": True, "dry_run": True, **preview}
|
1039 |
+
|
1040 |
+
# Jam policy
|
1041 |
+
if _any_jam_running():
|
1042 |
+
if req.stop_active:
|
1043 |
+
_stop_all_jams()
|
1044 |
+
else:
|
1045 |
+
raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
|
1046 |
+
|
1047 |
+
# Clear checkpoint + asset env so get_mrt() uses stock weights
|
1048 |
+
for k in ("MRT_CKPT_REPO", "MRT_CKPT_REV", "MRT_CKPT_STEP", "MRT_ASSETS_REPO"):
|
1049 |
+
os.environ.pop(k, None)
|
1050 |
+
os.environ["MRT_SIZE"] = str(tgt["size"])
|
1051 |
+
|
1052 |
+
# Rebuild model and optionally prewarm
|
1053 |
+
global _MRT
|
1054 |
+
with _MRT_LOCK:
|
1055 |
+
_MRT = None
|
1056 |
+
if req.prewarm:
|
1057 |
+
get_mrt()
|
1058 |
+
|
1059 |
+
return {"ok": True, **preview}
|
1060 |
+
|
1061 |
+
# ---------- CASE 2: select a repo + step (supports "latest") ----------
|
1062 |
if not tgt["repo"]:
|
1063 |
+
raise HTTPException(status_code=400, detail="repo_id is required for model selection.")
|
1064 |
|
1065 |
+
# 1) enumerate available steps
|
|
|
1066 |
steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
|
1067 |
if not steps:
|
1068 |
return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
|
1069 |
|
1070 |
+
# 2) choose step (explicit or latest)
|
1071 |
chosen_step = int(tgt["step"]) if tgt["step"] is not None else steps[-1]
|
1072 |
if chosen_step not in steps:
|
1073 |
return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}
|
1074 |
|
1075 |
+
# 3) optional finetune assets probe (no downloads, just listing)
|
1076 |
+
assets_ok, assets_msg = True, "skipped"
|
|
|
1077 |
if req.sync_assets:
|
1078 |
try:
|
|
|
1079 |
api = HfApi()
|
1080 |
files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
|
1081 |
if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
|
|
|
1090 |
"target_repo": tgt["repo"],
|
1091 |
"target_revision": tgt["rev"],
|
1092 |
"target_step": chosen_step,
|
1093 |
+
"assets_repo": (tgt["assets"] if req.sync_assets else None),
|
1094 |
"assets_probe": {"ok": assets_ok, "message": assets_msg},
|
1095 |
"active_jam": _any_jam_running(),
|
1096 |
}
|
1097 |
+
|
1098 |
if req.dry_run:
|
1099 |
return {"ok": True, "dry_run": True, **preview}
|
1100 |
|
1101 |
+
# Jam policy
|
1102 |
if _any_jam_running():
|
1103 |
if req.stop_active:
|
1104 |
_stop_all_jams()
|
1105 |
else:
|
1106 |
raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
|
1107 |
|
1108 |
+
# 4) atomic swap with rollback
|
1109 |
old_env = {
|
1110 |
+
"MRT_SIZE": os.getenv("MRT_SIZE"),
|
1111 |
+
"MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
|
1112 |
+
"MRT_CKPT_REV": os.getenv("MRT_CKPT_REV"),
|
1113 |
+
"MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
|
1114 |
+
"MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
|
1115 |
}
|
1116 |
try:
|
1117 |
+
os.environ["MRT_SIZE"] = str(tgt["size"])
|
1118 |
os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
|
1119 |
os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
|
1120 |
os.environ["MRT_CKPT_STEP"] = str(chosen_step)
|
|
|
1121 |
if req.sync_assets:
|
1122 |
os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])
|
1123 |
|
|
|
1126 |
with _MRT_LOCK:
|
1127 |
_MRT = None
|
1128 |
|
1129 |
+
# optionally load finetune assets now
|
1130 |
if req.sync_assets:
|
1131 |
_load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
|
1132 |
|
1133 |
+
# optional prewarm to amortize JIT
|
1134 |
if req.prewarm:
|
1135 |
+
get_mrt()
|
1136 |
|
1137 |
return {"ok": True, **preview}
|
1138 |
except Exception as e:
|
|
|
1151 |
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
1152 |
|
1153 |
|
1154 |
+
|
1155 |
@app.post("/generate")
|
1156 |
def generate(
|
1157 |
loop_audio: UploadFile = File(...),
|