Commit
·
30fdbbc
1
Parent(s):
384e4ac
updating docs a bit
Browse files- app.py +30 -179
- documentation.html +302 -43
app.py
CHANGED
@@ -77,45 +77,22 @@ from pydantic import BaseModel
|
|
77 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
78 |
|
79 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
80 |
-
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
81 |
_ASSETS_REPO_ID: str | None = None
|
82 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
83 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
84 |
|
85 |
-
_STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
|
86 |
|
87 |
# Create instances (these don't modify globals)
|
88 |
asset_manager = AssetManager()
|
89 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
90 |
|
91 |
# Sync asset manager with existing globals
|
92 |
-
def _sync_asset_manager():
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
# def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
|
98 |
-
# """
|
99 |
-
# List available checkpoint steps in a HF model repo without downloading all weights.
|
100 |
-
# Looks for:
|
101 |
-
# checkpoint_<step>/
|
102 |
-
# checkpoint_<step>.tgz | .tar.gz
|
103 |
-
# archives/checkpoint_<step>.tgz | .tar.gz
|
104 |
-
# """
|
105 |
-
# api = HfApi()
|
106 |
-
# files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
|
107 |
-
# steps = set()
|
108 |
-
# for f in files:
|
109 |
-
# m = _STEP_RE.search(f)
|
110 |
-
# if m:
|
111 |
-
# try:
|
112 |
-
# steps.add(int(m.group(1)))
|
113 |
-
# except:
|
114 |
-
# pass
|
115 |
-
# return sorted(steps)
|
116 |
-
|
117 |
-
# def _step_exists(repo_id: str, revision: str, step: int) -> bool:
|
118 |
-
# return step in _list_ckpt_steps(repo_id, revision)
|
119 |
|
120 |
def _any_jam_running() -> bool:
|
121 |
with jam_lock:
|
@@ -129,132 +106,6 @@ def _stop_all_jams(timeout: float = 5.0):
|
|
129 |
w.join(timeout=timeout)
|
130 |
jam_registry.pop(sid, None)
|
131 |
|
132 |
-
# def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
|
133 |
-
# """
|
134 |
-
# Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
|
135 |
-
# Safe to call multiple times; will overwrite globals if successful.
|
136 |
-
# """
|
137 |
-
# global _ASSETS_REPO_ID, _MEAN_EMBED, _CENTROIDS
|
138 |
-
# repo_id = repo_id or _FINETUNE_REPO_DEFAULT
|
139 |
-
# try:
|
140 |
-
# from huggingface_hub import hf_hub_download
|
141 |
-
# mean_path = None
|
142 |
-
# cent_path = None
|
143 |
-
# try:
|
144 |
-
# mean_path = hf_hub_download(repo_id, filename="mean_style_embed.npy", repo_type="model")
|
145 |
-
# except Exception:
|
146 |
-
# pass
|
147 |
-
# try:
|
148 |
-
# cent_path = hf_hub_download(repo_id, filename="cluster_centroids.npy", repo_type="model")
|
149 |
-
# except Exception:
|
150 |
-
# pass
|
151 |
-
|
152 |
-
# if mean_path is None and cent_path is None:
|
153 |
-
# return False, f"No finetune asset files found in repo {repo_id}"
|
154 |
-
|
155 |
-
# if mean_path is not None:
|
156 |
-
# m = np.load(mean_path)
|
157 |
-
# if m.ndim != 1:
|
158 |
-
# return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
|
159 |
-
# else:
|
160 |
-
# m = None
|
161 |
-
|
162 |
-
# if cent_path is not None:
|
163 |
-
# c = np.load(cent_path)
|
164 |
-
# if c.ndim != 2:
|
165 |
-
# return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"
|
166 |
-
# else:
|
167 |
-
# c = None
|
168 |
-
|
169 |
-
# # Optional: shape check vs model embedding dim once model is alive
|
170 |
-
# try:
|
171 |
-
# d = int(get_mrt().style_model.config.embedding_dim)
|
172 |
-
# if m is not None and m.shape[0] != d:
|
173 |
-
# return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
|
174 |
-
# if c is not None and c.shape[1] != d:
|
175 |
-
# return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"
|
176 |
-
# except Exception:
|
177 |
-
# # Model not built yet; we'll trust the files and rely on runtime checks later
|
178 |
-
# pass
|
179 |
-
|
180 |
-
# _MEAN_EMBED = m.astype(np.float32, copy=False) if m is not None else None
|
181 |
-
# _CENTROIDS = c.astype(np.float32, copy=False) if c is not None else None
|
182 |
-
# _ASSETS_REPO_ID = repo_id
|
183 |
-
# logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
|
184 |
-
# repo_id,
|
185 |
-
# "yes" if _MEAN_EMBED is not None else "no",
|
186 |
-
# f"{_CENTROIDS.shape[0]}x{_CENTROIDS.shape[1]}" if _CENTROIDS is not None else "no")
|
187 |
-
# return True, "ok"
|
188 |
-
# except Exception as e:
|
189 |
-
# logging.exception("Failed to load finetune assets: %s", e)
|
190 |
-
# return False, str(e)
|
191 |
-
|
192 |
-
# def _ensure_assets_loaded():
|
193 |
-
# # Best-effort lazy load if nothing is loaded yet
|
194 |
-
# if _MEAN_EMBED is None and _CENTROIDS is None:
|
195 |
-
# _load_finetune_assets_from_hf(_ASSETS_REPO_ID or _FINETUNE_REPO_DEFAULT)
|
196 |
-
# ------------------------------------------------------------------------------
|
197 |
-
|
198 |
-
# def _resolve_checkpoint_dir() -> str | None:
|
199 |
-
# repo_id = os.getenv("MRT_CKPT_REPO")
|
200 |
-
# if not repo_id:
|
201 |
-
# return None
|
202 |
-
# step = os.getenv("MRT_CKPT_STEP") # e.g. "1863001"
|
203 |
-
|
204 |
-
# root = Path(snapshot_download(
|
205 |
-
# repo_id=repo_id,
|
206 |
-
# repo_type="model",
|
207 |
-
# revision=os.getenv("MRT_CKPT_REV", "main"),
|
208 |
-
# local_dir="/home/appuser/.cache/mrt_ckpt/repo",
|
209 |
-
# local_dir_use_symlinks=False,
|
210 |
-
# ))
|
211 |
-
|
212 |
-
# # Prefer an archive if present (more reliable for Zarr/T5X)
|
213 |
-
# arch_names = [
|
214 |
-
# f"checkpoint_{step}.tgz",
|
215 |
-
# f"checkpoint_{step}.tar.gz",
|
216 |
-
# f"archives/checkpoint_{step}.tgz",
|
217 |
-
# f"archives/checkpoint_{step}.tar.gz",
|
218 |
-
# ] if step else []
|
219 |
-
|
220 |
-
# cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
|
221 |
-
# cache_root.mkdir(parents=True, exist_ok=True)
|
222 |
-
# for name in arch_names:
|
223 |
-
# arch = root / name
|
224 |
-
# if arch.is_file():
|
225 |
-
# out_dir = cache_root / f"checkpoint_{step}"
|
226 |
-
# marker = out_dir.with_suffix(".ok")
|
227 |
-
# if not marker.exists():
|
228 |
-
# out_dir.mkdir(parents=True, exist_ok=True)
|
229 |
-
# with tarfile.open(arch, "r:*") as tf:
|
230 |
-
# tf.extractall(out_dir)
|
231 |
-
# marker.write_text("ok")
|
232 |
-
# # sanity: require .zarray to exist inside the extracted tree
|
233 |
-
# if not any(out_dir.rglob(".zarray")):
|
234 |
-
# raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
|
235 |
-
# return str(out_dir / f"checkpoint_{step}") if (out_dir / f"checkpoint_{step}").exists() else str(out_dir)
|
236 |
-
|
237 |
-
# # No archive; try raw folder from repo and sanity check.
|
238 |
-
# if step:
|
239 |
-
# raw = root / f"checkpoint_{step}"
|
240 |
-
# if raw.is_dir():
|
241 |
-
# if not any(raw.rglob(".zarray")):
|
242 |
-
# raise RuntimeError(
|
243 |
-
# f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
|
244 |
-
# "Upload as a .tgz or push via git from a Unix shell."
|
245 |
-
# )
|
246 |
-
# return str(raw)
|
247 |
-
|
248 |
-
# # Pick latest if no step
|
249 |
-
# step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
|
250 |
-
# if step_dirs:
|
251 |
-
# pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
|
252 |
-
# if not any(pick.rglob(".zarray")):
|
253 |
-
# raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
|
254 |
-
# return str(pick)
|
255 |
-
|
256 |
-
# return None
|
257 |
-
|
258 |
|
259 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
260 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
@@ -328,19 +179,19 @@ try:
|
|
328 |
except Exception:
|
329 |
_HAS_LOUDNORM = False
|
330 |
|
331 |
-
def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
|
345 |
def build_style_vector(
|
346 |
mrt,
|
@@ -518,6 +369,11 @@ def _mrt_warmup():
|
|
518 |
# Never crash on warmup errors; log and continue serving
|
519 |
logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
|
520 |
|
|
|
|
|
|
|
|
|
|
|
521 |
# Kick it off in the background on server start
|
522 |
@app.on_event("startup")
|
523 |
def _kickoff_warmup():
|
@@ -640,17 +496,6 @@ def model_checkpoints(repo_id: str, revision: str = "main"):
|
|
640 |
steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
|
641 |
return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
|
642 |
|
643 |
-
# class ModelSelect(BaseModel):
|
644 |
-
# size: Optional[Literal["base","large"]] = None
|
645 |
-
# repo_id: Optional[str] = None
|
646 |
-
# revision: Optional[str] = "main"
|
647 |
-
# step: Optional[Union[int, str]] = None # allow "latest"
|
648 |
-
# assets_repo_id: Optional[str] = None # default: follow repo_id
|
649 |
-
# sync_assets: bool = True # load mean/centroids from repo
|
650 |
-
# prewarm: bool = False # call get_mrt() to build right away
|
651 |
-
# stop_active: bool = True # auto-stop jams; else 409
|
652 |
-
# dry_run: bool = False # validate only, don't swap
|
653 |
-
|
654 |
@app.post("/model/select")
|
655 |
def model_select(req: ModelSelect):
|
656 |
global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
|
@@ -733,6 +578,12 @@ def model_select(req: ModelSelect):
|
|
733 |
except Exception:
|
734 |
pass
|
735 |
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
736 |
|
737 |
|
738 |
|
|
|
77 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
78 |
|
79 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
80 |
+
# _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
81 |
_ASSETS_REPO_ID: str | None = None
|
82 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
83 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
84 |
|
85 |
+
# _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
|
86 |
|
87 |
# Create instances (these don't modify globals)
|
88 |
asset_manager = AssetManager()
|
89 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
90 |
|
91 |
# Sync asset manager with existing globals
|
92 |
+
# def _sync_asset_manager():
|
93 |
+
# asset_manager.mean_embed = _MEAN_EMBED
|
94 |
+
# asset_manager.centroids = _CENTROIDS
|
95 |
+
# asset_manager.assets_repo_id = _ASSETS_REPO_ID
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def _any_jam_running() -> bool:
|
98 |
with jam_lock:
|
|
|
106 |
w.join(timeout=timeout)
|
107 |
jam_registry.pop(sid, None)
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
111 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
|
|
179 |
except Exception:
|
180 |
_HAS_LOUDNORM = False
|
181 |
|
182 |
+
# def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
|
183 |
+
# extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
|
184 |
+
# if not extra:
|
185 |
+
# return mrt.embed_style("warmup")
|
186 |
+
# sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
|
187 |
+
# embeds, weights = [], []
|
188 |
+
# for i, s in enumerate(extra):
|
189 |
+
# embeds.append(mrt.embed_style(s))
|
190 |
+
# weights.append(sw[i] if i < len(sw) else 1.0)
|
191 |
+
# wsum = sum(weights) or 1.0
|
192 |
+
# weights = [w/wsum for w in weights]
|
193 |
+
# import numpy as np
|
194 |
+
# return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
|
195 |
|
196 |
def build_style_vector(
|
197 |
mrt,
|
|
|
369 |
# Never crash on warmup errors; log and continue serving
|
370 |
logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
|
371 |
|
372 |
+
|
373 |
+
# ----------------------------
|
374 |
+
# startup and model selection
|
375 |
+
# ----------------------------
|
376 |
+
|
377 |
# Kick it off in the background on server start
|
378 |
@app.on_event("startup")
|
379 |
def _kickoff_warmup():
|
|
|
496 |
steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
|
497 |
return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
@app.post("/model/select")
|
500 |
def model_select(req: ModelSelect):
|
501 |
global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
|
|
|
578 |
except Exception:
|
579 |
pass
|
580 |
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
581 |
+
|
582 |
+
|
583 |
+
|
584 |
+
# ----------------------------
|
585 |
+
# one-shot generation
|
586 |
+
# ----------------------------
|
587 |
|
588 |
|
589 |
|
documentation.html
CHANGED
@@ -4,67 +4,326 @@
|
|
4 |
<meta charset="utf-8">
|
5 |
<title>MagentaRT Research API</title>
|
6 |
<style>
|
7 |
-
body {
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
ul { line-height: 1.8; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
</style>
|
13 |
</head>
|
14 |
<body>
|
15 |
-
<
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
<
|
23 |
-
<
|
24 |
-
<
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
"type": "start",
|
33 |
"mode": "rt",
|
34 |
"binary_audio": false,
|
35 |
"params": {
|
36 |
-
"styles": "
|
|
|
37 |
"temperature": 1.1,
|
38 |
"topk": 40,
|
39 |
"guidance_weight": 1.1,
|
40 |
-
"pace": "realtime",
|
41 |
-
"
|
|
|
|
|
42 |
}
|
43 |
}</pre>
|
44 |
-
|
45 |
-
|
|
|
46 |
"type": "update",
|
47 |
"styles": "jazz, hiphop",
|
48 |
-
"style_weights": "1.0,0.8",
|
49 |
"temperature": 1.2,
|
50 |
"topk": 64,
|
51 |
"guidance_weight": 1.0,
|
52 |
-
"
|
53 |
-
"
|
54 |
}</pre>
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
<
|
62 |
-
|
63 |
-
<
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
</body>
|
70 |
</html>
|
|
|
4 |
<meta charset="utf-8">
|
5 |
<title>MagentaRT Research API</title>
|
6 |
<style>
|
7 |
+
body {
|
8 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
9 |
+
max-width: 900px;
|
10 |
+
margin: 48px auto;
|
11 |
+
padding: 0 24px;
|
12 |
+
color: #111;
|
13 |
+
line-height: 1.6;
|
14 |
+
}
|
15 |
+
.header { text-align: center; margin-bottom: 48px; }
|
16 |
+
.badge {
|
17 |
+
display: inline-block;
|
18 |
+
background: #ff6b35;
|
19 |
+
color: white;
|
20 |
+
padding: 4px 12px;
|
21 |
+
border-radius: 16px;
|
22 |
+
font-size: 0.85em;
|
23 |
+
font-weight: 500;
|
24 |
+
margin-left: 8px;
|
25 |
+
}
|
26 |
+
code, pre {
|
27 |
+
background: #f6f8fa;
|
28 |
+
border: 1px solid #eaecef;
|
29 |
+
border-radius: 6px;
|
30 |
+
font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, monospace;
|
31 |
+
}
|
32 |
+
code { padding: 2px 6px; }
|
33 |
+
pre {
|
34 |
+
padding: 16px;
|
35 |
+
overflow-x: auto;
|
36 |
+
margin: 16px 0;
|
37 |
+
position: relative;
|
38 |
+
}
|
39 |
+
.copy-btn {
|
40 |
+
position: absolute;
|
41 |
+
top: 8px;
|
42 |
+
right: 8px;
|
43 |
+
background: #0969da;
|
44 |
+
color: white;
|
45 |
+
border: none;
|
46 |
+
border-radius: 4px;
|
47 |
+
padding: 4px 8px;
|
48 |
+
font-size: 12px;
|
49 |
+
cursor: pointer;
|
50 |
+
}
|
51 |
+
.copy-btn:hover { background: #0550ae; }
|
52 |
+
.muted { color: #656d76; }
|
53 |
+
.warning {
|
54 |
+
background: #fff8c5;
|
55 |
+
border: 1px solid #e3b341;
|
56 |
+
border-radius: 8px;
|
57 |
+
padding: 16px;
|
58 |
+
margin: 16px 0;
|
59 |
+
}
|
60 |
+
.info {
|
61 |
+
background: #dbeafe;
|
62 |
+
border: 1px solid #3b82f6;
|
63 |
+
border-radius: 8px;
|
64 |
+
padding: 16px;
|
65 |
+
margin: 16px 0;
|
66 |
+
}
|
67 |
ul { line-height: 1.8; }
|
68 |
+
.endpoint {
|
69 |
+
background: #f8f9fa;
|
70 |
+
border-left: 4px solid #0969da;
|
71 |
+
padding: 12px 16px;
|
72 |
+
margin: 12px 0;
|
73 |
+
}
|
74 |
+
.demo-placeholder {
|
75 |
+
background: #f6f8fa;
|
76 |
+
border: 2px dashed #d1d9e0;
|
77 |
+
border-radius: 8px;
|
78 |
+
padding: 48px;
|
79 |
+
text-align: center;
|
80 |
+
margin: 24px 0;
|
81 |
+
color: #656d76;
|
82 |
+
}
|
83 |
+
.grid {
|
84 |
+
display: grid;
|
85 |
+
grid-template-columns: 1fr 1fr;
|
86 |
+
gap: 24px;
|
87 |
+
margin: 24px 0;
|
88 |
+
}
|
89 |
+
.card {
|
90 |
+
background: #f8f9fa;
|
91 |
+
border: 1px solid #e1e8ed;
|
92 |
+
border-radius: 8px;
|
93 |
+
padding: 20px;
|
94 |
+
}
|
95 |
+
a { color: #0969da; text-decoration: none; }
|
96 |
+
a:hover { text-decoration: underline; }
|
97 |
+
.section { margin: 48px 0; }
|
98 |
</style>
|
99 |
</head>
|
100 |
<body>
|
101 |
+
<div class="header">
|
102 |
+
<h1>🎵 MagentaRT Research API</h1>
|
103 |
+
<p class="muted"><strong>AI Music Generation API</strong> • Real-time streaming • Custom fine-tuning support</p>
|
104 |
+
<span class="badge">Research Project</span>
|
105 |
+
</div>
|
106 |
+
|
107 |
+
<div class="demo-placeholder">
|
108 |
+
<h3>📱 App Demo Video</h3>
|
109 |
+
<p>Demo video will be embedded here<br>
|
110 |
+
<small>Showing the iPhone app generating music in real-time</small></p>
|
111 |
+
</div>
|
112 |
+
|
113 |
+
<div class="section">
|
114 |
+
<h2>Overview</h2>
|
115 |
+
<p>This API powers AI music generation using Google's MagentaRT, designed for real-time audio streaming and custom model fine-tuning. Built for iOS app integration with WebSocket streaming support.</p>
|
116 |
+
|
117 |
+
<div class="info">
|
118 |
+
<strong>Hardware Requirements:</strong> Optimal performance requires an L40S GPU (48GB VRAM) for real-time streaming. L4 24GB works but may not maintain real-time performance.
|
119 |
+
</div>
|
120 |
+
</div>
|
121 |
+
|
122 |
+
<div class="section">
|
123 |
+
<h2>Quick Start - WebSocket Streaming</h2>
|
124 |
+
<p>Connect to <code>wss://&lt;your-space&gt;/ws/jam</code> for real-time audio generation:</p>
|
125 |
+
|
126 |
+
<h3>Start Real-time Generation</h3>
|
127 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{
|
128 |
"type": "start",
|
129 |
"mode": "rt",
|
130 |
"binary_audio": false,
|
131 |
"params": {
|
132 |
+
"styles": "electronic, ambient",
|
133 |
+
"style_weights": "1.0, 0.8",
|
134 |
"temperature": 1.1,
|
135 |
"topk": 40,
|
136 |
"guidance_weight": 1.1,
|
137 |
+
"pace": "realtime",
|
138 |
+
"style_ramp_seconds": 8.0,
|
139 |
+
"mean": 0.0,
|
140 |
+
"centroid_weights": "0.0, 0.0, 0.0"
|
141 |
}
|
142 |
}</pre>
|
143 |
+
|
144 |
+
<h3>Update Parameters Live</h3>
|
145 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{
|
146 |
"type": "update",
|
147 |
"styles": "jazz, hiphop",
|
148 |
+
"style_weights": "1.0, 0.8",
|
149 |
"temperature": 1.2,
|
150 |
"topk": 64,
|
151 |
"guidance_weight": 1.0,
|
152 |
+
"mean": 0.2,
|
153 |
+
"centroid_weights": "0.1, 0.3, 0.0"
|
154 |
}</pre>
|
155 |
+
|
156 |
+
<h3>Stop Generation</h3>
|
157 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{"type": "stop"}</pre>
|
158 |
+
</div>
|
159 |
+
|
160 |
+
<div class="section">
|
161 |
+
<h2>API Endpoints</h2>
|
162 |
+
|
163 |
+
<div class="endpoint">
|
164 |
+
<strong>POST /generate</strong> - Generate 4–8 bars of music with input audio
|
165 |
+
</div>
|
166 |
+
|
167 |
+
<div class="endpoint">
|
168 |
+
<strong>POST /generate_style</strong> - Generate music from style prompts only (experimental)
|
169 |
+
</div>
|
170 |
+
|
171 |
+
<div class="endpoint">
|
172 |
+
<strong>POST /jam/start</strong> - Start continuous jamming session
|
173 |
+
</div>
|
174 |
+
|
175 |
+
<div class="endpoint">
|
176 |
+
<strong>GET /jam/next</strong> - Get next audio chunk from session
|
177 |
+
</div>
|
178 |
+
|
179 |
+
<div class="endpoint">
|
180 |
+
<strong>POST /jam/consume</strong> - Mark chunk as consumed
|
181 |
+
</div>
|
182 |
+
|
183 |
+
<div class="endpoint">
|
184 |
+
<strong>POST /jam/stop</strong> - End jamming session
|
185 |
+
</div>
|
186 |
+
|
187 |
+
<div class="endpoint">
|
188 |
+
<strong>WEBSOCKET /ws/jam</strong> - Real-time streaming interface
|
189 |
+
</div>
|
190 |
+
|
191 |
+
<div class="endpoint">
|
192 |
+
<strong>POST /model/select</strong> - Switch between base and fine-tuned models
|
193 |
+
</div>
|
194 |
+
</div>
|
195 |
+
|
196 |
+
<div class="section">
|
197 |
+
<h2>Custom Fine-Tuning</h2>
|
198 |
+
<p>Train your own MagentaRT models and use them with this API and the iOS app.</p>
|
199 |
+
|
200 |
+
<div class="grid">
|
201 |
+
<div class="card">
|
202 |
+
<h3>1. Train Your Model</h3>
|
203 |
+
<p>Use the official MagentaRT fine-tuning notebook:</p>
|
204 |
+
<p><a href="https://github.com/magenta-realtime/notebooks/blob/main/Magenta_RT_Finetune.ipynb" target="_blank">📓 MagentaRT Fine-tuning Colab</a></p>
|
205 |
+
<p>This will create checkpoint folders like:</p>
|
206 |
+
<ul>
|
207 |
+
<li><code>checkpoint_1861001/</code></li>
|
208 |
+
<li><code>checkpoint_1862001/</code></li>
|
209 |
+
<li>And steering assets: <code>cluster_centroids.npy</code>, <code>mean_style_embed.npy</code></li>
|
210 |
+
</ul>
|
211 |
+
</div>
|
212 |
+
|
213 |
+
<div class="card">
|
214 |
+
<h3>2. Package Checkpoints</h3>
|
215 |
+
<p>Checkpoints must be compressed as .tgz files to preserve .zarray files correctly.</p>
|
216 |
+
<div class="warning">
|
217 |
+
<strong>Important:</strong> Do not download checkpoint folders directly from Google Drive - the .zarray files won't transfer properly.
|
218 |
+
</div>
|
219 |
+
</div>
|
220 |
+
</div>
|
221 |
+
|
222 |
+
<h3>Checkpoint Packaging Script</h3>
|
223 |
+
<p>Use this in a Colab cell to properly package your checkpoints:</p>
|
224 |
+
<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button># Mount Drive to access your trained checkpoints
|
225 |
+
from google.colab import drive
|
226 |
+
drive.mount('/content/drive')
|
227 |
+
|
228 |
+
# Set the path to your checkpoint folder
|
229 |
+
CKPT_SRC = '/content/drive/MyDrive/thepatch/checkpoint_1862001' # Adjust path
|
230 |
+
|
231 |
+
# Copy folder to local storage (preserves dotfiles)
|
232 |
+
!rm -rf /content/checkpoint_1862001
|
233 |
+
!cp -a "$CKPT_SRC" /content/
|
234 |
+
|
235 |
+
# Verify .zarray files are present
|
236 |
+
!find /content/checkpoint_1862001 -name .zarray | wc -l
|
237 |
+
|
238 |
+
# Create properly formatted .tgz archive
|
239 |
+
!tar -C /content -czf /content/checkpoint_1862001.tgz checkpoint_1862001
|
240 |
+
|
241 |
+
# Verify critical files are in the archive
|
242 |
+
!tar -tzf /content/checkpoint_1862001.tgz | grep -c '.zarray'
|
243 |
+
|
244 |
+
# Download the .tgz file
|
245 |
+
from google.colab import files
|
246 |
+
files.download('/content/checkpoint_1862001.tgz')</pre>
|
247 |
+
|
248 |
+
<h3>3. Upload to Hugging Face</h3>
|
249 |
+
<p>Create a model repository and upload:</p>
|
250 |
+
<ul>
|
251 |
+
<li>Your <code>.tgz</code> checkpoint files</li>
|
252 |
+
<li><code>cluster_centroids.npy</code> (for steering)</li>
|
253 |
+
<li><code>mean_style_embed.npy</code> (for steering)</li>
|
254 |
+
</ul>
|
255 |
+
|
256 |
+
<div class="info">
|
257 |
+
<strong>Example Repository:</strong> <a href="https://huggingface.co/thepatch/magenta-ft" target="_blank">thepatch/magenta-ft</a><br>
|
258 |
+
Shows the correct file structure with .tgz files and .npy steering assets in the root directory.
|
259 |
+
</div>
|
260 |
+
|
261 |
+
<h3>4. Use in the App</h3>
|
262 |
+
<p>In the iOS app's model selector, point to your Hugging Face repository URL. The app will automatically discover available checkpoints and allow switching between them.</p>
|
263 |
+
</div>
|
264 |
+
|
265 |
+
<div class="section">
|
266 |
+
<h2>Technical Specifications</h2>
|
267 |
+
<ul>
|
268 |
+
<li><strong>Audio Format:</strong> 48 kHz stereo, ~2.0s chunks with ~40ms crossfade</li>
|
269 |
+
<li><strong>Model Sizes:</strong> Base and Large variants available</li>
|
270 |
+
<li><strong>Steering:</strong> Support for text prompts, audio embeddings, and centroid-based fine-tune steering</li>
|
271 |
+
<li><strong>Real-time Performance:</strong> L40S recommended; L4 may experience slight delays</li>
|
272 |
+
<li><strong>Memory Requirements:</strong> ~40GB VRAM for sustained real-time streaming</li>
|
273 |
+
</ul>
|
274 |
+
|
275 |
+
<div class="warning">
|
276 |
+
<strong>Note:</strong> The <code>/generate_style</code> endpoint is experimental and may not properly adhere to the requested BPM, because it currently starts from silence with no rhythmic context (we are considering a metronome-based context instead of silence).
|
277 |
+
</div>
|
278 |
+
</div>
|
279 |
+
|
280 |
+
<div class="section">
|
281 |
+
<h2>Integration with iOS App</h2>
|
282 |
+
<p>This API is designed to work seamlessly with our iOS music generation app:</p>
|
283 |
+
<ul>
|
284 |
+
<li>Real-time audio streaming via WebSockets</li>
|
285 |
+
<li>Dynamic model switching between base and fine-tuned models</li>
|
286 |
+
<li>Integration with stable-audio-open-small for combined input audio generation</li>
|
287 |
+
<li>Live parameter adjustment during generation</li>
|
288 |
+
</ul>
|
289 |
+
</div>
|
290 |
+
|
291 |
+
<div class="section">
|
292 |
+
<h2>Deployment</h2>
|
293 |
+
<p>To run your own instance:</p>
|
294 |
+
<ol>
|
295 |
+
<li>Duplicate this Hugging Face Space</li>
|
296 |
+
<li>Ensure you have access to an L40S GPU</li>
|
297 |
+
<li>Point your iOS app to the new space URL (e.g., <code>https://your-username-magenta-retry.hf.space</code>)</li>
|
298 |
+
<li>Upload your fine-tuned models as described above</li>
|
299 |
+
</ol>
|
300 |
+
</div>
|
301 |
+
|
302 |
+
<div class="section">
|
303 |
+
<h2>Support & Contact</h2>
|
304 |
+
<p>This is an active research project. For questions, technical support, or collaboration:</p>
|
305 |
+
<p><strong>Email:</strong> <a href="mailto:kev@thecollabagepatch.com">kev@thecollabagepatch.com</a></p>
|
306 |
+
|
307 |
+
<div class="info">
|
308 |
+
<strong>Research Status:</strong> This project is under active development. Features and API may change. We welcome feedback and contributions from the research community.
|
309 |
+
</div>
|
310 |
+
</div>
|
311 |
+
|
312 |
+
<div class="section">
|
313 |
+
<h2>Licensing</h2>
|
314 |
+
<p>Built on Google's MagentaRT (Apache 2.0 + CC-BY 4.0). Users are responsible for their generated outputs and ensuring compliance with applicable laws and platform policies.</p>
|
315 |
+
<p><a href="/docs">📖 API Reference Documentation</a></p>
|
316 |
+
</div>
|
317 |
+
|
318 |
+
<script>
|
319 |
+
function copyCode(button) {
|
320 |
+
const pre = button.parentElement;
|
321 |
+
const code = pre.textContent.replace('Copy', '').trim();
|
322 |
+
navigator.clipboard.writeText(code).then(() => {
|
323 |
+
button.textContent = 'Copied!';
|
324 |
+
setTimeout(() => button.textContent = 'Copy', 2000);
|
325 |
+
});
|
326 |
+
}
|
327 |
+
</script>
|
328 |
</body>
|
329 |
</html>
|