thecollabagepatch committed on
Commit
147fd8b
·
1 Parent(s): 406bd0f

attempting to use finetunes

Browse files
Files changed (1) hide show
  1. app.py +78 -1
app.py CHANGED
@@ -66,6 +66,51 @@ except Exception:
66
  class ClientDisconnected(Exception): # fallback
67
  pass
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  async def send_json_safe(ws: WebSocket, obj) -> bool:
70
  """Try to send. Returns False if the socket is (or becomes) closed."""
71
  if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
@@ -569,7 +614,14 @@ def get_mrt():
569
  if _MRT is None:
570
  with _MRT_LOCK:
571
  if _MRT is None:
572
- _MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
 
 
 
 
 
 
 
573
  return _MRT
574
 
575
  _WARMED = False
@@ -648,6 +700,31 @@ def _kickoff_warmup():
648
  if os.getenv("MRT_WARMUP", "1") != "0":
649
  threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
  @app.post("/generate")
652
  def generate(
653
  loop_audio: UploadFile = File(...),
 
66
  class ClientDisconnected(Exception): # fallback
67
  pass
68
 
69
+ import re
70
+ from pathlib import Path
71
+
72
+ def _resolve_checkpoint_dir() -> str | None:
73
+ """
74
+ Returns a local directory path for MagentaRT(checkpoint_dir=...),
75
+ using a Hugging Face model repo that contains subfolders like:
76
+ checkpoint_1861001/, checkpoint_1862001/, ...
77
+ """
78
+ repo_id = os.getenv("MRT_CKPT_REPO")
79
+ if not repo_id:
80
+ return None # fall back to builtin 'base'/'large' assets
81
+
82
+ step = os.getenv("MRT_CKPT_STEP") # e.g., "1863001"
83
+ allow = None
84
+ if step:
85
+ # only pull that step + optional centroid files
86
+ allow = [f"checkpoint_{step}/**", "cluster_centroids.npy", "mean_style_embed.npy"]
87
+
88
+ from huggingface_hub import snapshot_download
89
+ local = snapshot_download(
90
+ repo_id=repo_id,
91
+ repo_type="model",
92
+ local_dir="/home/appuser/.cache/mrt_ckpt/repo",
93
+ local_dir_use_symlinks=False,
94
+ allow_patterns=allow or ["*"], # whole repo if no step provided
95
+ )
96
+ root = Path(local)
97
+
98
+ # If a step is specified, return that subfolder
99
+ if step:
100
+ cand = root / f"checkpoint_{step}"
101
+ if cand.is_dir():
102
+ return str(cand)
103
+
104
+ # Otherwise pick the numerically latest checkpoint_* folder
105
+ step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
106
+ if step_dirs:
107
+ pick = max(step_dirs, key=lambda d: int(d.name.split("_")[-1]))
108
+ return str(pick)
109
+
110
+ # Fallback: repo itself might already be a single checkpoint directory
111
+ return str(root)
112
+
113
+
114
  async def send_json_safe(ws: WebSocket, obj) -> bool:
115
  """Try to send. Returns False if the socket is (or becomes) closed."""
116
  if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
 
614
  if _MRT is None:
615
  with _MRT_LOCK:
616
  if _MRT is None:
617
+ ckpt_dir = _resolve_checkpoint_dir() # points to checkpoint_1863001
618
+ _MRT = system.MagentaRT(
619
+ tag=os.getenv("MRT_SIZE", "large"), # keep 'large' if finetuned from large
620
+ guidance_weight=5.0,
621
+ device="gpu",
622
+ checkpoint_dir=ckpt_dir, # ← uses your finetune
623
+ lazy=False,
624
+ )
625
  return _MRT
626
 
627
  _WARMED = False
 
700
  if os.getenv("MRT_WARMUP", "1") != "0":
701
  threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
702
 
703
@app.get("/model/status")
def model_status():
    """Report the configuration of the currently loaded MagentaRT model.

    Calls ``get_mrt()``, so the model is created first if it does not exist
    yet. Returns a JSON-serializable dict with the model tag, codec/config
    timing values, and the checkpoint selection taken from the environment.
    """
    mrt = get_mrt()
    repo = os.getenv("MRT_CKPT_REPO")
    return {
        # `_tag` is a private attribute; fall back rather than fail if absent.
        "tag": getattr(mrt, "_tag", "unknown"),
        # A checkpoint directory is only resolved when MRT_CKPT_REPO is set;
        # otherwise the builtin assets are used, so do not claim True blindly.
        "using_checkpoint_dir": bool(repo),
        "codec_frame_rate": float(mrt.codec.frame_rate),
        "decoder_rvq_depth": int(mrt.config.decoder_codec_rvq_depth),
        "context_seconds": float(mrt.config.context_length),
        "chunk_seconds": float(mrt.config.chunk_length),
        "crossfade_seconds": float(mrt.config.crossfade_length),
        "selected_step": os.getenv("MRT_CKPT_STEP"),
        "repo": repo,
    }
717
+
718
@app.post("/model/swap")
def model_swap(step: int = Form(...)):
    """Select a different checkpoint step and drop the cached model.

    Writes ``step`` to the ``MRT_CKPT_STEP`` environment variable and clears
    the module-level ``_MRT`` under ``_MRT_LOCK``. Nothing is loaded here:
    the model is re-created by the next ``get_mrt()`` call.

    Returns a dict echoing the requested step.
    """
    global _MRT

    # Active jams are not stopped here; add that first if strictness is needed.
    selected_step = str(step)
    os.environ["MRT_CKPT_STEP"] = selected_step

    # Clearing the cached instance forces re-creation on the next get_mrt().
    with _MRT_LOCK:
        _MRT = None

    # Pre-warming could be done at this point by calling get_mrt().
    result = {"reloaded": True, "step": step}
    return result
727
+
728
  @app.post("/generate")
729
  def generate(
730
  loop_audio: UploadFile = File(...),