thecollabagepatch commited on
Commit
30fdbbc
Β·
1 Parent(s): 384e4ac

updating docs a bit

Browse files
Files changed (2) hide show
  1. app.py +30 -179
  2. documentation.html +302 -43
app.py CHANGED
@@ -77,45 +77,22 @@ from pydantic import BaseModel
77
  from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
78
 
79
  # ---- Finetune assets (mean & centroids) --------------------------------------
80
- _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
81
  _ASSETS_REPO_ID: str | None = None
82
  _MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
83
  _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
84
 
85
- _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
86
 
87
  # Create instances (these don't modify globals)
88
  asset_manager = AssetManager()
89
  model_selector = ModelSelector(CheckpointManager(), asset_manager)
90
 
91
  # Sync asset manager with existing globals
92
- def _sync_asset_manager():
93
- asset_manager.mean_embed = _MEAN_EMBED
94
- asset_manager.centroids = _CENTROIDS
95
- asset_manager.assets_repo_id = _ASSETS_REPO_ID
96
-
97
- # def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
98
- # """
99
- # List available checkpoint steps in a HF model repo without downloading all weights.
100
- # Looks for:
101
- # checkpoint_<step>/
102
- # checkpoint_<step>.tgz | .tar.gz
103
- # archives/checkpoint_<step>.tgz | .tar.gz
104
- # """
105
- # api = HfApi()
106
- # files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
107
- # steps = set()
108
- # for f in files:
109
- # m = _STEP_RE.search(f)
110
- # if m:
111
- # try:
112
- # steps.add(int(m.group(1)))
113
- # except:
114
- # pass
115
- # return sorted(steps)
116
-
117
- # def _step_exists(repo_id: str, revision: str, step: int) -> bool:
118
- # return step in _list_ckpt_steps(repo_id, revision)
119
 
120
  def _any_jam_running() -> bool:
121
  with jam_lock:
@@ -129,132 +106,6 @@ def _stop_all_jams(timeout: float = 5.0):
129
  w.join(timeout=timeout)
130
  jam_registry.pop(sid, None)
131
 
132
- # def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
133
- # """
134
- # Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
135
- # Safe to call multiple times; will overwrite globals if successful.
136
- # """
137
- # global _ASSETS_REPO_ID, _MEAN_EMBED, _CENTROIDS
138
- # repo_id = repo_id or _FINETUNE_REPO_DEFAULT
139
- # try:
140
- # from huggingface_hub import hf_hub_download
141
- # mean_path = None
142
- # cent_path = None
143
- # try:
144
- # mean_path = hf_hub_download(repo_id, filename="mean_style_embed.npy", repo_type="model")
145
- # except Exception:
146
- # pass
147
- # try:
148
- # cent_path = hf_hub_download(repo_id, filename="cluster_centroids.npy", repo_type="model")
149
- # except Exception:
150
- # pass
151
-
152
- # if mean_path is None and cent_path is None:
153
- # return False, f"No finetune asset files found in repo {repo_id}"
154
-
155
- # if mean_path is not None:
156
- # m = np.load(mean_path)
157
- # if m.ndim != 1:
158
- # return False, f"mean_style_embed.npy must be 1-D (got {m.shape})"
159
- # else:
160
- # m = None
161
-
162
- # if cent_path is not None:
163
- # c = np.load(cent_path)
164
- # if c.ndim != 2:
165
- # return False, f"cluster_centroids.npy must be 2-D (got {c.shape})"
166
- # else:
167
- # c = None
168
-
169
- # # Optional: shape check vs model embedding dim once model is alive
170
- # try:
171
- # d = int(get_mrt().style_model.config.embedding_dim)
172
- # if m is not None and m.shape[0] != d:
173
- # return False, f"mean_style_embed dim {m.shape[0]} != model dim {d}"
174
- # if c is not None and c.shape[1] != d:
175
- # return False, f"cluster_centroids dim {c.shape[1]} != model dim {d}"
176
- # except Exception:
177
- # # Model not built yet; we’ll trust the files and rely on runtime checks later
178
- # pass
179
-
180
- # _MEAN_EMBED = m.astype(np.float32, copy=False) if m is not None else None
181
- # _CENTROIDS = c.astype(np.float32, copy=False) if c is not None else None
182
- # _ASSETS_REPO_ID = repo_id
183
- # logging.info("Loaded finetune assets from %s (mean=%s, centroids=%s)",
184
- # repo_id,
185
- # "yes" if _MEAN_EMBED is not None else "no",
186
- # f"{_CENTROIDS.shape[0]}x{_CENTROIDS.shape[1]}" if _CENTROIDS is not None else "no")
187
- # return True, "ok"
188
- # except Exception as e:
189
- # logging.exception("Failed to load finetune assets: %s", e)
190
- # return False, str(e)
191
-
192
- # def _ensure_assets_loaded():
193
- # # Best-effort lazy load if nothing is loaded yet
194
- # if _MEAN_EMBED is None and _CENTROIDS is None:
195
- # _load_finetune_assets_from_hf(_ASSETS_REPO_ID or _FINETUNE_REPO_DEFAULT)
196
- # ------------------------------------------------------------------------------
197
-
198
- # def _resolve_checkpoint_dir() -> str | None:
199
- # repo_id = os.getenv("MRT_CKPT_REPO")
200
- # if not repo_id:
201
- # return None
202
- # step = os.getenv("MRT_CKPT_STEP") # e.g. "1863001"
203
-
204
- # root = Path(snapshot_download(
205
- # repo_id=repo_id,
206
- # repo_type="model",
207
- # revision=os.getenv("MRT_CKPT_REV", "main"),
208
- # local_dir="/home/appuser/.cache/mrt_ckpt/repo",
209
- # local_dir_use_symlinks=False,
210
- # ))
211
-
212
- # # Prefer an archive if present (more reliable for Zarr/T5X)
213
- # arch_names = [
214
- # f"checkpoint_{step}.tgz",
215
- # f"checkpoint_{step}.tar.gz",
216
- # f"archives/checkpoint_{step}.tgz",
217
- # f"archives/checkpoint_{step}.tar.gz",
218
- # ] if step else []
219
-
220
- # cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
221
- # cache_root.mkdir(parents=True, exist_ok=True)
222
- # for name in arch_names:
223
- # arch = root / name
224
- # if arch.is_file():
225
- # out_dir = cache_root / f"checkpoint_{step}"
226
- # marker = out_dir.with_suffix(".ok")
227
- # if not marker.exists():
228
- # out_dir.mkdir(parents=True, exist_ok=True)
229
- # with tarfile.open(arch, "r:*") as tf:
230
- # tf.extractall(out_dir)
231
- # marker.write_text("ok")
232
- # # sanity: require .zarray to exist inside the extracted tree
233
- # if not any(out_dir.rglob(".zarray")):
234
- # raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
235
- # return str(out_dir / f"checkpoint_{step}") if (out_dir / f"checkpoint_{step}").exists() else str(out_dir)
236
-
237
- # # No archive; try raw folder from repo and sanity check.
238
- # if step:
239
- # raw = root / f"checkpoint_{step}"
240
- # if raw.is_dir():
241
- # if not any(raw.rglob(".zarray")):
242
- # raise RuntimeError(
243
- # f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
244
- # "Upload as a .tgz or push via git from a Unix shell."
245
- # )
246
- # return str(raw)
247
-
248
- # # Pick latest if no step
249
- # step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
250
- # if step_dirs:
251
- # pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
252
- # if not any(pick.rglob(".zarray")):
253
- # raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
254
- # return str(pick)
255
-
256
- # return None
257
-
258
 
259
  async def send_json_safe(ws: WebSocket, obj) -> bool:
260
  """Try to send. Returns False if the socket is (or becomes) closed."""
@@ -328,19 +179,19 @@ try:
328
  except Exception:
329
  _HAS_LOUDNORM = False
330
 
331
- def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
332
- extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
333
- if not extra:
334
- return mrt.embed_style("warmup")
335
- sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
336
- embeds, weights = [], []
337
- for i, s in enumerate(extra):
338
- embeds.append(mrt.embed_style(s))
339
- weights.append(sw[i] if i < len(sw) else 1.0)
340
- wsum = sum(weights) or 1.0
341
- weights = [w/wsum for w in weights]
342
- import numpy as np
343
- return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
344
 
345
  def build_style_vector(
346
  mrt,
@@ -518,6 +369,11 @@ def _mrt_warmup():
518
  # Never crash on warmup errors; log and continue serving
519
  logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
520
 
 
 
 
 
 
521
  # Kick it off in the background on server start
522
  @app.on_event("startup")
523
  def _kickoff_warmup():
@@ -640,17 +496,6 @@ def model_checkpoints(repo_id: str, revision: str = "main"):
640
  steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
641
  return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
642
 
643
- # class ModelSelect(BaseModel):
644
- # size: Optional[Literal["base","large"]] = None
645
- # repo_id: Optional[str] = None
646
- # revision: Optional[str] = "main"
647
- # step: Optional[Union[int, str]] = None # allow "latest"
648
- # assets_repo_id: Optional[str] = None # default: follow repo_id
649
- # sync_assets: bool = True # load mean/centroids from repo
650
- # prewarm: bool = False # call get_mrt() to build right away
651
- # stop_active: bool = True # auto-stop jams; else 409
652
- # dry_run: bool = False # validate only, don't swap
653
-
654
  @app.post("/model/select")
655
  def model_select(req: ModelSelect):
656
  global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
@@ -733,6 +578,12 @@ def model_select(req: ModelSelect):
733
  except Exception:
734
  pass
735
  raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
 
 
 
 
 
 
736
 
737
 
738
 
 
77
  from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
78
 
79
  # ---- Finetune assets (mean & centroids) --------------------------------------
80
+ # _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
81
  _ASSETS_REPO_ID: str | None = None
82
  _MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
83
  _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
84
 
85
+ # _STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz|\.tgz)?$")
86
 
87
  # Create instances (these don't modify globals)
88
  asset_manager = AssetManager()
89
  model_selector = ModelSelector(CheckpointManager(), asset_manager)
90
 
91
  # Sync asset manager with existing globals
92
+ # def _sync_asset_manager():
93
+ # asset_manager.mean_embed = _MEAN_EMBED
94
+ # asset_manager.centroids = _CENTROIDS
95
+ # asset_manager.assets_repo_id = _ASSETS_REPO_ID
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def _any_jam_running() -> bool:
98
  with jam_lock:
 
106
  w.join(timeout=timeout)
107
  jam_registry.pop(sid, None)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  async def send_json_safe(ws: WebSocket, obj) -> bool:
111
  """Try to send. Returns False if the socket is (or becomes) closed."""
 
179
  except Exception:
180
  _HAS_LOUDNORM = False
181
 
182
+ # def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
183
+ # extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
184
+ # if not extra:
185
+ # return mrt.embed_style("warmup")
186
+ # sw = [float(x) for x in (weights_str or "").split(",") if x.strip()]
187
+ # embeds, weights = [], []
188
+ # for i, s in enumerate(extra):
189
+ # embeds.append(mrt.embed_style(s))
190
+ # weights.append(sw[i] if i < len(sw) else 1.0)
191
+ # wsum = sum(weights) or 1.0
192
+ # weights = [w/wsum for w in weights]
193
+ # import numpy as np
194
+ # return np.sum([w*e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
195
 
196
  def build_style_vector(
197
  mrt,
 
369
  # Never crash on warmup errors; log and continue serving
370
  logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
371
 
372
+
373
+ # ----------------------------
374
+ # startup and model selection
375
+ # ----------------------------
376
+
377
  # Kick it off in the background on server start
378
  @app.on_event("startup")
379
  def _kickoff_warmup():
 
496
  steps = CheckpointManager.list_ckpt_steps(repo_id, revision)
497
  return {"repo": repo_id, "revision": revision, "steps": steps, "latest": (steps[-1] if steps else None)}
498
 
 
 
 
 
 
 
 
 
 
 
 
499
  @app.post("/model/select")
500
  def model_select(req: ModelSelect):
501
  global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
 
578
  except Exception:
579
  pass
580
  raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
581
+
582
+
583
+
584
+ # ----------------------------
585
+ # one-shot generation
586
+ # ----------------------------
587
 
588
 
589
 
documentation.html CHANGED
@@ -4,67 +4,326 @@
4
  <meta charset="utf-8">
5
  <title>MagentaRT Research API</title>
6
  <style>
7
- body { font-family: Arial, sans-serif; max-width: 860px; margin: 48px auto; padding: 0 20px; color:#111; }
8
- code, pre { background:#f6f8fa; border:1px solid #eaecef; border-radius:6px; padding:2px 6px; }
9
- pre { padding:12px; overflow:auto; }
10
- .muted { color:#555; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  ul { line-height: 1.8; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  </style>
13
  </head>
14
  <body>
15
- <h1>🎡 MagentaRT Research API</h1>
16
- <p class="muted"><strong>Purpose:</strong> AI music generation for iOS/web app research using Google's MagentaRT.</p>
17
-
18
- <h2>Available Endpoints</h2>
19
- <ul>
20
- <li><code>POST /generate</code> – Generate 4–8 bars of music (HTTP, bar-aligned)</li>
21
- <li><code>POST /jam/start</code> – Start continuous jamming (HTTP)</li>
22
- <li><code>GET /jam/next</code> – Get next chunk (HTTP)</li>
23
- <li><code>POST /jam/consume</code> – Confirm a chunk as consumed (HTTP)</li>
24
- <li><code>POST /jam/stop</code> – End session (HTTP)</li>
25
- <li><code>WEBSOCKET /ws/jam</code> – Realtime streaming (<code>mode="rt"</code>)</li>
26
- <li><code>GET /docs</code> – API documentation (Gradio)</li>
27
- </ul>
28
-
29
- <h2>WebSocket Quick Start (rt mode)</h2>
30
- <p>Connect to <code>wss://&lt;your-space&gt;/ws/jam</code> and send:</p>
31
- <pre>{
 
 
 
 
 
 
 
 
 
 
32
  "type": "start",
33
  "mode": "rt",
34
  "binary_audio": false,
35
  "params": {
36
- "styles": "warmup",
 
37
  "temperature": 1.1,
38
  "topk": 40,
39
  "guidance_weight": 1.1,
40
- "pace": "realtime", // or "asap" to bootstrap quickly
41
- "max_decode_frames": 50 // default ~2.0s; try 36–45 on smaller GPUs
 
 
42
  }
43
  }</pre>
44
- <p>Update parameters live:</p>
45
- <pre>{
 
46
  "type": "update",
47
  "styles": "jazz, hiphop",
48
- "style_weights": "1.0,0.8",
49
  "temperature": 1.2,
50
  "topk": 64,
51
  "guidance_weight": 1.0,
52
- "pace": "realtime",
53
- "max_decode_frames": 40
54
  }</pre>
55
- <p>Stop:</p>
56
- <pre>{"type":"stop"}</pre>
57
-
58
- <h2>Notes</h2>
59
- <ul>
60
- <li>Audio: 48 kHz stereo, ~2.0 s chunks by default with ~40 ms crossfade.</li>
61
- <li>L40S 48GB: faster than realtime β†’ prefer <code>pace: "realtime"</code>.</li>
62
- <li>L4 24GB: slightly under realtime even with pre-roll and tuning.</li>
63
- <li>For sustained realtime, target ~40 GB VRAM per active stream (e.g., A100 40GB or β‰ˆ35–40 GB MIG slice).</li>
64
- </ul>
65
-
66
- <p class="muted"><strong>Licensing:</strong> Uses MagentaRT (Apache 2.0 + CC-BY 4.0). Users are responsible for outputs.</p>
67
- <p>See <a href="../blob/main/docs" target="_blank">documentation files</a> for detailed guides.</p>
68
- <p>Or <a href="/docs">/docs</a> for auto-generated FastAPI reference.</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  </body>
70
  </html>
 
4
  <meta charset="utf-8">
5
  <title>MagentaRT Research API</title>
6
  <style>
7
+ body {
8
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
9
+ max-width: 900px;
10
+ margin: 48px auto;
11
+ padding: 0 24px;
12
+ color: #111;
13
+ line-height: 1.6;
14
+ }
15
+ .header { text-align: center; margin-bottom: 48px; }
16
+ .badge {
17
+ display: inline-block;
18
+ background: #ff6b35;
19
+ color: white;
20
+ padding: 4px 12px;
21
+ border-radius: 16px;
22
+ font-size: 0.85em;
23
+ font-weight: 500;
24
+ margin-left: 8px;
25
+ }
26
+ code, pre {
27
+ background: #f6f8fa;
28
+ border: 1px solid #eaecef;
29
+ border-radius: 6px;
30
+ font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, monospace;
31
+ }
32
+ code { padding: 2px 6px; }
33
+ pre {
34
+ padding: 16px;
35
+ overflow-x: auto;
36
+ margin: 16px 0;
37
+ position: relative;
38
+ }
39
+ .copy-btn {
40
+ position: absolute;
41
+ top: 8px;
42
+ right: 8px;
43
+ background: #0969da;
44
+ color: white;
45
+ border: none;
46
+ border-radius: 4px;
47
+ padding: 4px 8px;
48
+ font-size: 12px;
49
+ cursor: pointer;
50
+ }
51
+ .copy-btn:hover { background: #0550ae; }
52
+ .muted { color: #656d76; }
53
+ .warning {
54
+ background: #fff8c5;
55
+ border: 1px solid #e3b341;
56
+ border-radius: 8px;
57
+ padding: 16px;
58
+ margin: 16px 0;
59
+ }
60
+ .info {
61
+ background: #dbeafe;
62
+ border: 1px solid #3b82f6;
63
+ border-radius: 8px;
64
+ padding: 16px;
65
+ margin: 16px 0;
66
+ }
67
  ul { line-height: 1.8; }
68
+ .endpoint {
69
+ background: #f8f9fa;
70
+ border-left: 4px solid #0969da;
71
+ padding: 12px 16px;
72
+ margin: 12px 0;
73
+ }
74
+ .demo-placeholder {
75
+ background: #f6f8fa;
76
+ border: 2px dashed #d1d9e0;
77
+ border-radius: 8px;
78
+ padding: 48px;
79
+ text-align: center;
80
+ margin: 24px 0;
81
+ color: #656d76;
82
+ }
83
+ .grid {
84
+ display: grid;
85
+ grid-template-columns: 1fr 1fr;
86
+ gap: 24px;
87
+ margin: 24px 0;
88
+ }
89
+ .card {
90
+ background: #f8f9fa;
91
+ border: 1px solid #e1e8ed;
92
+ border-radius: 8px;
93
+ padding: 20px;
94
+ }
95
+ a { color: #0969da; text-decoration: none; }
96
+ a:hover { text-decoration: underline; }
97
+ .section { margin: 48px 0; }
98
  </style>
99
  </head>
100
  <body>
101
+ <div class="header">
102
+ <h1>🎡 MagentaRT Research API</h1>
103
+ <p class="muted"><strong>AI Music Generation API</strong> β€’ Real-time streaming β€’ Custom fine-tuning support</p>
104
+ <span class="badge">Research Project</span>
105
+ </div>
106
+
107
+ <div class="demo-placeholder">
108
+ <h3>πŸ“± App Demo Video</h3>
109
+ <p>Demo video will be embedded here<br>
110
+ <small>Showing the iPhone app generating music in real-time</small></p>
111
+ </div>
112
+
113
+ <div class="section">
114
+ <h2>Overview</h2>
115
+ <p>This API powers AI music generation using Google's MagentaRT, designed for real-time audio streaming and custom model fine-tuning. Built for iOS app integration with WebSocket streaming support.</p>
116
+
117
+ <div class="info">
118
+ <strong>Hardware Requirements:</strong> Optimal performance requires an L40S GPU (48GB VRAM) for real-time streaming. L4 24GB works but may not maintain real-time performance.
119
+ </div>
120
+ </div>
121
+
122
+ <div class="section">
123
+ <h2>Quick Start - WebSocket Streaming</h2>
124
+ <p>Connect to <code>wss://&lt;your-space&gt;/ws/jam</code> for real-time audio generation:</p>
125
+
126
+ <h3>Start Real-time Generation</h3>
127
+ <pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{
128
  "type": "start",
129
  "mode": "rt",
130
  "binary_audio": false,
131
  "params": {
132
+ "styles": "electronic, ambient",
133
+ "style_weights": "1.0, 0.8",
134
  "temperature": 1.1,
135
  "topk": 40,
136
  "guidance_weight": 1.1,
137
+ "pace": "realtime",
138
+ "style_ramp_seconds": 8.0,
139
+ "mean": 0.0,
140
+ "centroid_weights": "0.0, 0.0, 0.0"
141
  }
142
  }</pre>
143
+
144
+ <h3>Update Parameters Live</h3>
145
+ <pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{
146
  "type": "update",
147
  "styles": "jazz, hiphop",
148
+ "style_weights": "1.0, 0.8",
149
  "temperature": 1.2,
150
  "topk": 64,
151
  "guidance_weight": 1.0,
152
+ "mean": 0.2,
153
+ "centroid_weights": "0.1, 0.3, 0.0"
154
  }</pre>
155
+
156
+ <h3>Stop Generation</h3>
157
+ <pre><button class="copy-btn" onclick="copyCode(this)">Copy</button>{"type": "stop"}</pre>
158
+ </div>
159
+
160
+ <div class="section">
161
+ <h2>API Endpoints</h2>
162
+
163
+ <div class="endpoint">
164
+ <strong>POST /generate</strong> - Generate 4–8 bars of music with input audio
165
+ </div>
166
+
167
+ <div class="endpoint">
168
+ <strong>POST /generate_style</strong> - Generate music from style prompts only (experimental)
169
+ </div>
170
+
171
+ <div class="endpoint">
172
+ <strong>POST /jam/start</strong> - Start continuous jamming session
173
+ </div>
174
+
175
+ <div class="endpoint">
176
+ <strong>GET /jam/next</strong> - Get next audio chunk from session
177
+ </div>
178
+
179
+ <div class="endpoint">
180
+ <strong>POST /jam/consume</strong> - Mark chunk as consumed
181
+ </div>
182
+
183
+ <div class="endpoint">
184
+ <strong>POST /jam/stop</strong> - End jamming session
185
+ </div>
186
+
187
+ <div class="endpoint">
188
+ <strong>WEBSOCKET /ws/jam</strong> - Real-time streaming interface
189
+ </div>
190
+
191
+ <div class="endpoint">
192
+ <strong>POST /model/select</strong> - Switch between base and fine-tuned models
193
+ </div>
194
+ </div>
195
+
196
+ <div class="section">
197
+ <h2>Custom Fine-Tuning</h2>
198
+ <p>Train your own MagentaRT models and use them with this API and the iOS app.</p>
199
+
200
+ <div class="grid">
201
+ <div class="card">
202
+ <h3>1. Train Your Model</h3>
203
+ <p>Use the official MagentaRT fine-tuning notebook:</p>
204
+ <p><a href="https://github.com/magenta-realtime/notebooks/blob/main/Magenta_RT_Finetune.ipynb" target="_blank">πŸ”— MagentaRT Fine-tuning Colab</a></p>
205
+ <p>This will create checkpoint folders like:</p>
206
+ <ul>
207
+ <li><code>checkpoint_1861001/</code></li>
208
+ <li><code>checkpoint_1862001/</code></li>
209
+ <li>And steering assets: <code>cluster_centroids.npy</code>, <code>mean_style_embed.npy</code></li>
210
+ </ul>
211
+ </div>
212
+
213
+ <div class="card">
214
+ <h3>2. Package Checkpoints</h3>
215
+ <p>Checkpoints must be compressed as .tgz files to preserve .zarray files correctly.</p>
216
+ <div class="warning">
217
+ <strong>Important:</strong> Do not download checkpoint folders directly from Google Drive - the .zarray files won't transfer properly.
218
+ </div>
219
+ </div>
220
+ </div>
221
+
222
+ <h3>Checkpoint Packaging Script</h3>
223
+ <p>Use this in a Colab cell to properly package your checkpoints:</p>
224
+ <pre><button class="copy-btn" onclick="copyCode(this)">Copy</button># Mount Drive to access your trained checkpoints
225
+ from google.colab import drive
226
+ drive.mount('/content/drive')
227
+
228
+ # Set the path to your checkpoint folder
229
+ CKPT_SRC = '/content/drive/MyDrive/thepatch/checkpoint_1862001' # Adjust path
230
+
231
+ # Copy folder to local storage (preserves dotfiles)
232
+ !rm -rf /content/checkpoint_1862001
233
+ !cp -a "$CKPT_SRC" /content/
234
+
235
+ # Verify .zarray files are present
236
+ !find /content/checkpoint_1862001 -name .zarray | wc -l
237
+
238
+ # Create properly formatted .tgz archive
239
+ !tar -C /content -czf /content/checkpoint_1862001.tgz checkpoint_1862001
240
+
241
+ # Verify critical files are in the archive
242
+ !tar -tzf /content/checkpoint_1862001.tgz | grep -c '.zarray'
243
+
244
+ # Download the .tgz file
245
+ from google.colab import files
246
+ files.download('/content/checkpoint_1862001.tgz')</pre>
247
+
248
+ <h3>3. Upload to Hugging Face</h3>
249
+ <p>Create a model repository and upload:</p>
250
+ <ul>
251
+ <li>Your <code>.tgz</code> checkpoint files</li>
252
+ <li><code>cluster_centroids.npy</code> (for steering)</li>
253
+ <li><code>mean_style_embed.npy</code> (for steering)</li>
254
+ </ul>
255
+
256
+ <div class="info">
257
+ <strong>Example Repository:</strong> <a href="https://huggingface.co/thepatch/magenta-ft" target="_blank">thepatch/magenta-ft</a><br>
258
+ Shows the correct file structure with .tgz files and .npy steering assets in the root directory.
259
+ </div>
260
+
261
+ <h3>4. Use in the App</h3>
262
+ <p>In the iOS app's model selector, point to your Hugging Face repository URL. The app will automatically discover available checkpoints and allow switching between them.</p>
263
+ </div>
264
+
265
+ <div class="section">
266
+ <h2>Technical Specifications</h2>
267
+ <ul>
268
+ <li><strong>Audio Format:</strong> 48 kHz stereo, ~2.0s chunks with ~40ms crossfade</li>
269
+ <li><strong>Model Sizes:</strong> Base and Large variants available</li>
270
+ <li><strong>Steering:</strong> Support for text prompts, audio embeddings, and centroid-based fine-tune steering</li>
271
+ <li><strong>Real-time Performance:</strong> L40S recommended; L4 may experience slight delays</li>
272
+ <li><strong>Memory Requirements:</strong> ~40GB VRAM for sustained real-time streaming</li>
273
+ </ul>
274
+
275
+ <div class="warning">
276
+ <strong>Note:</strong> The <code>/generate_style</code> endpoint is experimental and may not properly adhere to BPM without additional context (considering metronome-based context instead of silence).
277
+ </div>
278
+ </div>
279
+
280
+ <div class="section">
281
+ <h2>Integration with iOS App</h2>
282
+ <p>This API is designed to work seamlessly with our iOS music generation app:</p>
283
+ <ul>
284
+ <li>Real-time audio streaming via WebSockets</li>
285
+ <li>Dynamic model switching between base and fine-tuned models</li>
286
+ <li>Integration with stable-audio-open-small for combined input audio generation</li>
287
+ <li>Live parameter adjustment during generation</li>
288
+ </ul>
289
+ </div>
290
+
291
+ <div class="section">
292
+ <h2>Deployment</h2>
293
+ <p>To run your own instance:</p>
294
+ <ol>
295
+ <li>Duplicate this Hugging Face Space</li>
296
+ <li>Ensure you have access to an L40S GPU</li>
297
+ <li>Point your iOS app to the new space URL (e.g., <code>https://your-username-magenta-retry.hf.space</code>)</li>
298
+ <li>Upload your fine-tuned models as described above</li>
299
+ </ol>
300
+ </div>
301
+
302
+ <div class="section">
303
+ <h2>Support & Contact</h2>
304
+ <p>This is an active research project. For questions, technical support, or collaboration:</p>
305
+ <p><strong>Email:</strong> <a href="mailto:kev@thecollabagepatch.com">kev@thecollabagepatch.com</a></p>
306
+
307
+ <div class="info">
308
+ <strong>Research Status:</strong> This project is under active development. Features and API may change. We welcome feedback and contributions from the research community.
309
+ </div>
310
+ </div>
311
+
312
+ <div class="section">
313
+ <h2>Licensing</h2>
314
+ <p>Built on Google's MagentaRT (Apache 2.0 + CC-BY 4.0). Users are responsible for their generated outputs and ensuring compliance with applicable laws and platform policies.</p>
315
+ <p><a href="/docs">πŸ“– API Reference Documentation</a></p>
316
+ </div>
317
+
318
+ <script>
319
+ function copyCode(button) {
320
+ const pre = button.parentElement;
321
+ const code = pre.textContent.replace('Copy', '').trim();
322
+ navigator.clipboard.writeText(code).then(() => {
323
+ button.textContent = 'Copied!';
324
+ setTimeout(() => button.textContent = 'Copy', 2000);
325
+ });
326
+ }
327
+ </script>
328
  </body>
329
  </html>