thecollabagepatch committed
Commit 2ed5fff · 1 Parent(s): 8f1aba9

fixing docs link for gradio interface

Files changed (2):
  1. app.py +204 -18
  2. documentation.html +2 -1
app.py CHANGED
@@ -292,17 +292,6 @@ def _patch_t5x_for_gpu_coords():
  # Call the patch immediately at import time (before MagentaRT init)
  _patch_t5x_for_gpu_coords()
 
- def load_doc_content(filename: str) -> str:
-     """Load markdown content from docs directory, with fallback."""
-     try:
-         doc_path = Path(__file__).parent / "docs" / filename
-         return doc_path.read_text(encoding='utf-8')
-     except FileNotFoundError:
-         return f"⚠️ Documentation file `{filename}` not found. Please check the docs directory."
-     except Exception as e:
-         return f"⚠️ Error loading `{filename}`: {e}"
-
-
  def create_documentation_interface():
      """Create a Gradio interface for documentation and transparency"""
      with gr.Blocks(title="MagentaRT Research API", theme=gr.themes.Soft()) as interface:
@@ -322,31 +311,223 @@ continuous music either as **bar-aligned chunks over HTTP** or as **low-latency
  # About & current status
  # ------------------------------------------------------------------
  with gr.Tab("📖 About & Status"):
- gr.Markdown(load_doc_content("about_status.md"))
+ gr.Markdown(
+ r"""
+ ## What this is
+ We're exploring AI‑assisted loop‑based music creation that can run on GPUs (not just TPUs) and stream to apps in realtime.
+
+ ### Implemented backends
+ - **HTTP (bar‑aligned):** `/generate`, `/jam/start`, `/jam/next`, `/jam/stop`, `/jam/update`, etc.
+ - **WebSocket (realtime):** `ws://…/ws/jam` with `mode="rt"` (Colab‑style continuous chunks). New in this build.
+
+ ## What we learned (GPU notes)
+ - **L40S 48GB:** comfortably **faster than realtime** → we added a `pace: "realtime"` switch so the server doesn't outrun playback.
+ - **L4 24GB:** **consistently just under realtime**; even with pre‑roll buffering, TF32/JAX tuning, reduced chunk size, and the **base** checkpoint, we still see eventual under‑runs.
+ - **Implication:** For production‑quality realtime, aim for ~**40GB VRAM** per user/session (e.g., **A100 40GB**, or MIG slices ≈ **35–40GB** on newer parts). Smaller GPUs can demo, but sustained realtime is not reliable.
+
+ ## Model / audio specs
+ - **Model:** MagentaRT (T5X; decoder RVQ depth = 16)
+ - **Audio:** 48 kHz stereo, 2.0 s chunks by default, 40 ms crossfade
+ - **Context:** 10 s rolling context window
+ """
+ )
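For orientation, the chunk math these specs imply is easy to work out: the WebSocket metadata documented below defines `chunk_seconds = chunk_frames / 25.0`, and audio is 48 kHz with a 40 ms crossfade. A quick illustrative sketch (not part of this commit):

```python
# Back-of-the-envelope chunk sizes implied by the documented specs:
# 25 decoder frames per second, 48 kHz output, 40 ms default crossfade.
SAMPLE_RATE = 48_000
FRAMES_PER_SECOND = 25.0      # metadata.chunk_seconds = chunk_frames / 25.0
CROSSFADE_SECONDS = 0.040     # 40 ms

for frames in (50, 45, 36):   # 50 is the default max_decode_frames
    seconds = frames / FRAMES_PER_SECOND
    samples = int(seconds * SAMPLE_RATE)
    overlap = int(CROSSFADE_SECONDS * SAMPLE_RATE)
    print(f"{frames} frames -> {seconds:.2f} s, {samples} samples/channel, "
          f"{overlap}-sample crossfade overlap")
```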
 
  # ------------------------------------------------------------------
  # HTTP API
  # ------------------------------------------------------------------
  with gr.Tab("🔧 API (HTTP)"):
- gr.Markdown(load_doc_content("api_http.md"))
+ gr.Markdown(
+ r"""
+ ### Single Generation
+ ```bash
+ curl -X POST \
+   "$HOST/generate" \
+   -F "loop_audio=@drum_loop.wav" \
+   -F "bpm=120" \
+   -F "bars=8" \
+   -F "styles=acid house,techno" \
+   -F "guidance_weight=5.0" \
+   -F "temperature=1.1"
+ ```
+
+ ### Continuous Jamming (bar‑aligned, HTTP)
+ ```bash
+ # 1) Start a session
+ curl -s -X POST "$HOST/jam/start" \
+   -F "loop_audio=@loop.wav" \
+   -F "bpm=120" \
+   -F "bars_per_chunk=8" | jq .
+ # → {"session_id":"…"}
+
+ # 2) Pull the next chunk (repeat)
+ curl "$HOST/jam/next?session_id=$SESSION"
+
+ # 3) Stop
+ curl -X POST "$HOST/jam/stop" \
+   -H "Content-Type: application/json" \
+   -d '{"session_id":"'$SESSION'"}'
+ ```
+
+ ### Common parameters
+ - **bpm** *(int)* – beats per minute
+ - **bars / bars_per_chunk** *(int)* – musical length
+ - **styles** *(str)* – comma‑separated text prompts (mixed internally)
+ - **guidance_weight** *(float)* – style adherence (CFG weight)
+ - **temperature / topk** – sampling controls
+ - **intro_bars_to_drop** *(int, `/generate`)* – generate the intro bars, then trim them from the returned audio
+ """
+ )
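The curl examples above translate directly to Python. Here's a minimal `requests` sketch of the same calls; the endpoint names and form fields come from the docs above, but treating the `/generate` and `/jam/next` response bodies as WAV bytes is an assumption, so adjust to whatever the server actually returns. Not part of this commit.

```python
# Hypothetical client for the HTTP endpoints documented above.
import requests

HOST = "https://YOUR_SPACE"  # placeholder

# Single generation: multipart form mirroring the curl example.
with open("drum_loop.wav", "rb") as f:
    resp = requests.post(
        f"{HOST}/generate",
        files={"loop_audio": f},
        data={
            "bpm": 120,
            "bars": 8,
            "styles": "acid house,techno",
            "guidance_weight": 5.0,
            "temperature": 1.1,
        },
        timeout=300,
    )
resp.raise_for_status()
with open("generated.wav", "wb") as out:
    out.write(resp.content)  # assumption: body is a WAV

# Bar-aligned jamming: start, pull a few chunks, stop.
with open("loop.wav", "rb") as f:
    start = requests.post(
        f"{HOST}/jam/start",
        files={"loop_audio": f},
        data={"bpm": 120, "bars_per_chunk": 8},
        timeout=300,
    ).json()
session_id = start["session_id"]

for i in range(4):
    chunk = requests.get(f"{HOST}/jam/next", params={"session_id": session_id}, timeout=300)
    print("chunk", i, len(chunk.content), "bytes")

requests.post(f"{HOST}/jam/stop", json={"session_id": session_id}, timeout=60)
```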
 
  # ------------------------------------------------------------------
- # WebSocket API: realtime ('rt' mode)
+ # WebSocket API: realtime (rt mode)
  # ------------------------------------------------------------------
  with gr.Tab("🧩 API (WebSocket • rt mode)"):
- gr.Markdown(load_doc_content("api_websocket.md"))
+ gr.Markdown(
+ r"""
+ Connect to `wss://…/ws/jam` and send a **JSON control stream**. In `rt` mode the server emits ~2 s WAV chunks (or binary frames) continuously.
+
+ ### Start (client → server)
+ ```jsonc
+ {
+   "type": "start",
+   "mode": "rt",
+   "binary_audio": false,        // true → raw WAV bytes + separate chunk_meta
+   "params": {
+     "styles": "heavy metal",    // or "jazz, hiphop"
+     "style_weights": "1.0,1.0", // optional, auto‑normalized
+     "temperature": 1.1,
+     "topk": 40,
+     "guidance_weight": 1.1,
+     "pace": "realtime",         // "realtime" | "asap" (default)
+     "max_decode_frames": 50     // 50 ≈ 2.0 s; try 36–45 on smaller GPUs
+   }
+ }
+ ```
+
+ ### Server events (server → client)
+ - `{"type":"started","mode":"rt"}` – handshake
+ - `{"type":"chunk","audio_base64":"…","metadata":{…}}` – base64 WAV
+   - `metadata.sample_rate` *(int)* – usually 48000
+   - `metadata.chunk_frames` *(int)* – e.g., 50
+   - `metadata.chunk_seconds` *(float)* – frames / 25.0
+   - `metadata.crossfade_seconds` *(float)* – typically 0.04
+ - `{"type":"chunk_meta","metadata":{…}}` – sent **after** a binary frame when `binary_audio=true`
+ - `{"type":"status",…}`, `{"type":"error",…}`, `{"type":"stopped"}`
+
+ ### Update (client → server)
+ ```jsonc
+ {
+   "type": "update",
+   "styles": "jazz, hiphop",
+   "style_weights": "1.0,0.8",
+   "temperature": 1.2,
+   "topk": 64,
+   "guidance_weight": 1.0,
+   "pace": "realtime",        // optional live flip
+   "max_decode_frames": 40    // optional; <= 50
+ }
+ ```
+
+ ### Stop / ping
+ ```json
+ {"type":"stop"}
+ {"type":"ping"}
+ ```
+
+ ### Browser quick‑start (schedules chunks seamlessly with a 25–40 ms crossfade)
+ ```html
+ <script>
+ const XFADE = 0.025; // 25 ms
+ let ctx, gain, ws, nextTime = 0;
+ async function start(){
+   ctx = new (window.AudioContext||window.webkitAudioContext)();
+   gain = ctx.createGain(); gain.connect(ctx.destination);
+   ws = new WebSocket("wss://YOUR_SPACE/ws/jam");
+   ws.onopen = () => ws.send(JSON.stringify({
+     type:"start", mode:"rt", binary_audio:false,
+     params:{ styles:"warmup", temperature:1.1, topk:40, guidance_weight:1.1, pace:"realtime" }
+   }));
+   ws.onmessage = async ev => {
+     const msg = JSON.parse(ev.data);
+     if (msg.type === "chunk" && msg.audio_base64){
+       const bin = atob(msg.audio_base64);
+       const buf = new Uint8Array(bin.length);
+       for (let i = 0; i < bin.length; i++) buf[i] = bin.charCodeAt(i);
+       const audio = await ctx.decodeAudioData(buf.buffer);
+       const src = ctx.createBufferSource(); const g = ctx.createGain();
+       src.buffer = audio; src.connect(g); g.connect(gain);
+       if (nextTime < ctx.currentTime + 0.05) nextTime = ctx.currentTime + 0.12;
+       const startAt = nextTime, dur = audio.duration;
+       nextTime = startAt + Math.max(0, dur - XFADE);
+       g.gain.setValueAtTime(0, startAt);
+       g.gain.linearRampToValueAtTime(1, startAt + XFADE);
+       g.gain.setValueAtTime(1, startAt + Math.max(0, dur - XFADE));
+       g.gain.linearRampToValueAtTime(0, startAt + dur);
+       src.start(startAt);
+     }
+   };
+ }
+ </script>
+ ```
+
+ ### Python client (async)
+ ```python
+ import asyncio, json, websockets, base64, soundfile as sf, io
+
+ async def run(url):
+     async with websockets.connect(url) as ws:
+         await ws.send(json.dumps({
+             "type": "start", "mode": "rt", "binary_audio": False,
+             "params": {"styles": "warmup", "temperature": 1.1, "topk": 40,
+                        "guidance_weight": 1.1, "pace": "realtime"}
+         }))
+         while True:
+             msg = json.loads(await ws.recv())
+             if msg.get("type") == "chunk":
+                 wav = base64.b64decode(msg["audio_base64"])  # bytes of a WAV
+                 x, sr = sf.read(io.BytesIO(wav), dtype="float32")
+                 print("chunk", x.shape, sr)
+             elif msg.get("type") in ("stopped", "error"):
+                 break
+
+ asyncio.run(run("wss://YOUR_SPACE/ws/jam"))
+ ```
+ """
+ )
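Both clients above use `binary_audio: false`. For `binary_audio: true`, the docs say each raw WAV frame is followed by a `chunk_meta` message; the sketch below pairs them on that basis (our reading of the stated ordering, not a verified protocol spec, and not part of this commit).

```python
# Sketch: consuming binary_audio=true frames, pairing each binary WAV frame
# with the chunk_meta JSON event that the docs say follows it.
import asyncio, io, json
import soundfile as sf
import websockets

async def run_binary(url):
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({
            "type": "start", "mode": "rt", "binary_audio": True,
            "params": {"styles": "warmup", "pace": "realtime"},
        }))
        pending_wav = None  # last binary frame, waiting for its metadata
        while True:
            msg = await ws.recv()
            if isinstance(msg, bytes):          # raw WAV chunk
                pending_wav = msg
                continue
            event = json.loads(msg)
            if event.get("type") == "chunk_meta" and pending_wav is not None:
                x, sr = sf.read(io.BytesIO(pending_wav), dtype="float32")
                print("chunk", x.shape, sr, event.get("metadata", {}))
                pending_wav = None
            elif event.get("type") in ("stopped", "error"):
                break

asyncio.run(run_binary("wss://YOUR_SPACE/ws/jam"))
```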
 
  # ------------------------------------------------------------------
  # Performance & hardware guidance
  # ------------------------------------------------------------------
  with gr.Tab("📊 Performance & Hardware"):
- gr.Markdown(load_doc_content("performance.md"))
+ gr.Markdown(
+ r"""
+ ### Current observations
+ - **L40S 48GB** → faster than realtime. Use `pace:"realtime"` to avoid client over‑buffering.
+ - **L4 24GB** → slightly **below** realtime even with pre‑roll buffering, TF32/Autotune, smaller chunks (`max_decode_frames`), and the **base** checkpoint.
+
+ ### Practical guidance
+ - For consistent realtime, target **~40GB VRAM per active stream** (e.g., **A100 40GB**, or MIG slices ≈ **35–40GB** on newer GPUs).
+ - Keep client‑side **overlap‑add** (25–40 ms) for seamless chunk joins.
+ - Prefer **`pace:"realtime"`** once playback begins; use **ASAP** only to build a short pre‑roll if needed.
+ - Optional knob: **`max_decode_frames`** (default **50** ≈ 2.0 s). Reducing to **36–45** can lower per‑chunk latency/VRAM, but doesn't increase frames/sec throughput.
+
+ ### Concurrency
+ This research build is designed for **one active jam per GPU**. Concurrency would require GPU partitioning (MIG) or horizontal scaling with a session scheduler.
+ """
+ )
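The client-side overlap-add advice above is the piece integrators most often get wrong, so here is a small numpy sketch of a linear crossfade between consecutive chunks. It is illustrative only (not from this repo); the fade length is our choice within the suggested 25–40 ms range.

```python
# Minimal client-side overlap-add: linearly crossfade the tail of the
# previous audio into the head of the next chunk.
import numpy as np

SR = 48_000
XFADE_S = 0.030                      # 30 ms, inside the suggested 25–40 ms range
N = int(SR * XFADE_S)                # overlap length in samples

def append_with_crossfade(stream: np.ndarray, chunk: np.ndarray) -> np.ndarray:
    """stream, chunk: float32 arrays shaped (samples, channels)."""
    if stream.shape[0] < N or chunk.shape[0] < N:
        return np.concatenate([stream, chunk])
    fade = np.linspace(0.0, 1.0, N, dtype=np.float32)[:, None]
    mixed = stream[-N:] * (1.0 - fade) + chunk[:N] * fade
    return np.concatenate([stream[:-N], mixed, chunk[N:]])

# Usage: start with the first chunk, then fold each new one in.
# out = first_chunk
# out = append_with_crossfade(out, next_chunk)
```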
 
  # ------------------------------------------------------------------
  # Changelog & legal
  # ------------------------------------------------------------------
  with gr.Tab("🗒️ Changelog & Legal"):
- gr.Markdown(load_doc_content("changelog.md"))
+ gr.Markdown(
+ r"""
+ ### Recent changes
+ - New **WebSocket realtime** route: `/ws/jam` (`mode:"rt"`)
+ - Added server pacing flag: `pace: "realtime" | "asap"`
+ - Exposed `max_decode_frames` for shorter chunks on smaller GPUs
+ - Client test page now does proper **overlap‑add** crossfade between chunks
+
+ ### Licensing
+ This project uses MagentaRT under:
+ - **Code:** Apache 2.0
+ - **Model weights:** CC‑BY 4.0
+ Please review the MagentaRT repo for full terms.
+ """
+ )
 
  gr.Markdown(
  r"""
 
@@ -1935,4 +2116,9 @@ def read_root():
  <p>Documentation file not found. Please check documentation.html</p>
  </body></html>
  """
- return Response(content=html_content, media_type="text/html")
+ return Response(content=html_content, media_type="text/html")
+
+ @app.get("/documentation")
+ def documentation():
+     interface = create_documentation_interface()
+     return gr.mount_gradio_app(app, interface, path="/documentation")
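For context on the new route: `gr.mount_gradio_app(app, blocks, path=...)` returns the FastAPI app with the Gradio Blocks mounted under that path, and it is typically called once at startup. A standalone sketch (names mirror the ones above, but this is illustrative, not code from this repo):

```python
# Standalone sketch: mounting a Gradio Blocks app on a FastAPI app at /documentation.
import gradio as gr
from fastapi import FastAPI

app = FastAPI()

def create_documentation_interface() -> gr.Blocks:
    # Placeholder docs interface standing in for the real one defined above.
    with gr.Blocks(title="MagentaRT Research API") as interface:
        gr.Markdown("## MagentaRT Research API\nDocumentation placeholder.")
    return interface

# Mount once at import/startup; Gradio then serves everything under /documentation.
app = gr.mount_gradio_app(app, create_documentation_interface(), path="/documentation")
```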
documentation.html CHANGED
@@ -64,6 +64,7 @@
  </ul>
 
  <p class="muted"><strong>Licensing:</strong> Uses MagentaRT (Apache 2.0 + CC-BY 4.0). Users are responsible for outputs.</p>
- <p>See <a href="/docs">/docs</a> for full API details and client examples.</p>
+ <p>See <a href="/documentation">/documentation</a> for full API details and client examples.</p>
+ <p>Or <a href="/docs">/docs</a> for auto-generated API reference.</p>
  </body>
  </html>