Commit
·
147fd8b
1
Parent(s):
406bd0f
attempting to use finetunes
Browse files
app.py
CHANGED
@@ -66,6 +66,51 @@ except Exception:
|
|
66 |
class ClientDisconnected(Exception): # fallback
|
67 |
pass
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
70 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
71 |
if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
|
@@ -569,7 +614,14 @@ def get_mrt():
|
|
569 |
if _MRT is None:
|
570 |
with _MRT_LOCK:
|
571 |
if _MRT is None:
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
return _MRT
|
574 |
|
575 |
_WARMED = False
|
@@ -648,6 +700,31 @@ def _kickoff_warmup():
|
|
648 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
649 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
@app.post("/generate")
|
652 |
def generate(
|
653 |
loop_audio: UploadFile = File(...),
|
|
|
66 |
class ClientDisconnected(Exception): # fallback
|
67 |
pass
|
68 |
|
69 |
+
import re
|
70 |
+
from pathlib import Path
|
71 |
+
|
72 |
+
def _resolve_checkpoint_dir() -> str | None:
|
73 |
+
"""
|
74 |
+
Returns a local directory path for MagentaRT(checkpoint_dir=...),
|
75 |
+
using a Hugging Face model repo that contains subfolders like:
|
76 |
+
checkpoint_1861001/, checkpoint_1862001/, ...
|
77 |
+
"""
|
78 |
+
repo_id = os.getenv("MRT_CKPT_REPO")
|
79 |
+
if not repo_id:
|
80 |
+
return None # fall back to builtin 'base'/'large' assets
|
81 |
+
|
82 |
+
step = os.getenv("MRT_CKPT_STEP") # e.g., "1863001"
|
83 |
+
allow = None
|
84 |
+
if step:
|
85 |
+
# only pull that step + optional centroid files
|
86 |
+
allow = [f"checkpoint_{step}/**", "cluster_centroids.npy", "mean_style_embed.npy"]
|
87 |
+
|
88 |
+
from huggingface_hub import snapshot_download
|
89 |
+
local = snapshot_download(
|
90 |
+
repo_id=repo_id,
|
91 |
+
repo_type="model",
|
92 |
+
local_dir="/home/appuser/.cache/mrt_ckpt/repo",
|
93 |
+
local_dir_use_symlinks=False,
|
94 |
+
allow_patterns=allow or ["*"], # whole repo if no step provided
|
95 |
+
)
|
96 |
+
root = Path(local)
|
97 |
+
|
98 |
+
# If a step is specified, return that subfolder
|
99 |
+
if step:
|
100 |
+
cand = root / f"checkpoint_{step}"
|
101 |
+
if cand.is_dir():
|
102 |
+
return str(cand)
|
103 |
+
|
104 |
+
# Otherwise pick the numerically latest checkpoint_* folder
|
105 |
+
step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
|
106 |
+
if step_dirs:
|
107 |
+
pick = max(step_dirs, key=lambda d: int(d.name.split("_")[-1]))
|
108 |
+
return str(pick)
|
109 |
+
|
110 |
+
# Fallback: repo itself might already be a single checkpoint directory
|
111 |
+
return str(root)
|
112 |
+
|
113 |
+
|
114 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
115 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
116 |
if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
|
|
|
614 |
if _MRT is None:
|
615 |
with _MRT_LOCK:
|
616 |
if _MRT is None:
|
617 |
+
ckpt_dir = _resolve_checkpoint_dir() # ← points to checkpoint_1863001
|
618 |
+
_MRT = system.MagentaRT(
|
619 |
+
tag=os.getenv("MRT_SIZE", "large"), # keep 'large' if finetuned from large
|
620 |
+
guidance_weight=5.0,
|
621 |
+
device="gpu",
|
622 |
+
checkpoint_dir=ckpt_dir, # ← uses your finetune
|
623 |
+
lazy=False,
|
624 |
+
)
|
625 |
return _MRT
|
626 |
|
627 |
_WARMED = False
|
|
|
700 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
701 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
702 |
|
703 |
+
@app.get("/model/status")
|
704 |
+
def model_status():
|
705 |
+
mrt = get_mrt()
|
706 |
+
return {
|
707 |
+
"tag": getattr(mrt, "_tag", "unknown"),
|
708 |
+
"using_checkpoint_dir": True,
|
709 |
+
"codec_frame_rate": float(mrt.codec.frame_rate),
|
710 |
+
"decoder_rvq_depth": int(mrt.config.decoder_codec_rvq_depth),
|
711 |
+
"context_seconds": float(mrt.config.context_length),
|
712 |
+
"chunk_seconds": float(mrt.config.chunk_length),
|
713 |
+
"crossfade_seconds": float(mrt.config.crossfade_length),
|
714 |
+
"selected_step": os.getenv("MRT_CKPT_STEP"),
|
715 |
+
"repo": os.getenv("MRT_CKPT_REPO"),
|
716 |
+
}
|
717 |
+
|
718 |
+
@app.post("/model/swap")
|
719 |
+
def model_swap(step: int = Form(...)):
|
720 |
+
# stop any active jam if you want to be strict (not shown)
|
721 |
+
os.environ["MRT_CKPT_STEP"] = str(step)
|
722 |
+
global _MRT
|
723 |
+
with _MRT_LOCK:
|
724 |
+
_MRT = None # force re-create on next get_mrt()
|
725 |
+
# optionally pre-warm here by calling get_mrt()
|
726 |
+
return {"reloaded": True, "step": step}
|
727 |
+
|
728 |
@app.post("/generate")
|
729 |
def generate(
|
730 |
loop_audio: UploadFile = File(...),
|