thecollabagepatch commited on
Commit
406bd0f
·
1 Parent(s): 4a4198e

one final http endpoint without input audio

Browse files
Files changed (1) hide show
  1. app.py +139 -1
app.py CHANGED
@@ -475,6 +475,77 @@ def generate_loop_continuation_with_mrt(
475
 
476
  return out, loud_stats
477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
 
480
  # ----------------------------
@@ -498,7 +569,7 @@ def get_mrt():
498
  if _MRT is None:
499
  with _MRT_LOCK:
500
  if _MRT is None:
501
- _MRT = system.MagentaRT(tag="base", guidance_weight=5.0, device="gpu", lazy=False)
502
  return _MRT
503
 
504
  _WARMED = False
@@ -663,6 +734,73 @@ def generate(
663
  }
664
  return {"audio_base64": audio_b64, "metadata": metadata}
665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
  # ----------------------------
667
  # the 'keep jamming' button
668
  # ----------------------------
 
475
 
476
  return out, loud_stats
477
 
478
+ def generate_style_only_with_mrt(
479
+ mrt,
480
+ bpm: float,
481
+ bars: int = 8,
482
+ beats_per_bar: int = 4,
483
+ styles: str = "warmup",
484
+ style_weights: str = "",
485
+ intro_bars_to_drop: int = 0,
486
+ ):
487
+ """
488
+ Style-only, bar-aligned generation using a silent context (no input audio).
489
+ Returns: (au.Waveform out, dict loud_stats_or_None)
490
+ """
491
+ # ---- Build a 10s silent context, tokenized for the model ----
492
+ codec_fps = float(mrt.codec.frame_rate)
493
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
494
+ sr = int(mrt.sample_rate)
495
+
496
+ silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
497
+ tokens_full = mrt.codec.encode(silent).astype(np.int32)
498
+ tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
499
+
500
+ state = mrt.init_state()
501
+ state.context_tokens = tokens
502
+
503
+ # ---- Style vector (text prompts only, normalized weights) ----
504
+ prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
505
+ if not prompts:
506
+ prompts = ["warmup"]
507
+ sw = [float(x) for x in style_weights.split(",")] if style_weights else []
508
+ embeds, weights = [], []
509
+ for i, p in enumerate(prompts):
510
+ embeds.append(mrt.embed_style(p))
511
+ weights.append(sw[i] if i < len(sw) else 1.0)
512
+ wsum = float(sum(weights)) or 1.0
513
+ weights = [w / wsum for w in weights]
514
+ style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
515
+
516
+ # ---- Target length math ----
517
+ seconds_per_bar = beats_per_bar * (60.0 / bpm)
518
+ total_secs = bars * seconds_per_bar
519
+ drop_bars = max(0, int(intro_bars_to_drop))
520
+ drop_secs = min(drop_bars, bars) * seconds_per_bar
521
+ gen_total_secs = total_secs + drop_secs
522
+
523
+ # ~2.0s chunk length from model config
524
+ chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
525
+
526
+ # Generate enough chunks to cover total, plus a pad chunk for crossfade headroom
527
+ steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
528
+
529
+ chunks = []
530
+ for _ in range(steps):
531
+ wav, state = mrt.generate_chunk(state=state, style=style_vec)
532
+ chunks.append(wav)
533
+
534
+ # Stitch & trim to exact musical length
535
+ stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
536
+ stitched = hard_trim_seconds(stitched, gen_total_secs)
537
+
538
+ if drop_secs > 0:
539
+ n_drop = int(round(drop_secs * stitched.sample_rate))
540
+ stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
541
+
542
+ out = hard_trim_seconds(stitched, total_secs)
543
+ out = out.peak_normalize(0.95)
544
+ apply_micro_fades(out, 5)
545
+
546
+ return out, None # loudness stats not applicable (no reference)
547
+
548
+
549
 
550
 
551
  # ----------------------------
 
569
  if _MRT is None:
570
  with _MRT_LOCK:
571
  if _MRT is None:
572
+ _MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
573
  return _MRT
574
 
575
  _WARMED = False
 
734
  }
735
  return {"audio_base64": audio_b64, "metadata": metadata}
736
 
737
+ # new endpoint to return a bar-aligned chunk without the need for combined audio
738
+
739
+ @app.post("/generate_style")
740
+ def generate_style(
741
+ bpm: float = Form(...),
742
+ bars: int = Form(8),
743
+ beats_per_bar: int = Form(4),
744
+ styles: str = Form("warmup"),
745
+ style_weights: str = Form(""),
746
+ guidance_weight: float = Form(1.1),
747
+ temperature: float = Form(1.1),
748
+ topk: int = Form(40),
749
+ target_sample_rate: int | None = Form(None),
750
+ intro_bars_to_drop: int = Form(0),
751
+ ):
752
+ """
753
+ Style-only, bar-aligned generation (no input audio).
754
+ Seeds with 10s of silent context; outputs exactly `bars` at the requested BPM.
755
+ """
756
+ mrt = get_mrt()
757
+
758
+ # Override sampling knobs just for this request
759
+ with mrt_overrides(mrt,
760
+ guidance_weight=guidance_weight,
761
+ temperature=temperature,
762
+ topk=topk):
763
+ wav, _ = generate_style_only_with_mrt(
764
+ mrt,
765
+ bpm=bpm,
766
+ bars=bars,
767
+ beats_per_bar=beats_per_bar,
768
+ styles=styles,
769
+ style_weights=style_weights,
770
+ intro_bars_to_drop=intro_bars_to_drop,
771
+ )
772
+
773
+ # Determine target SR (defaults to model SR = 48k)
774
+ cur_sr = int(mrt.sample_rate)
775
+ target_sr = int(target_sample_rate or cur_sr)
776
+ x = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
777
+
778
+ seconds_per_bar = (60.0 / float(bpm)) * int(beats_per_bar)
779
+ expected_secs = float(bars) * seconds_per_bar
780
+
781
+ # Snap exactly to musical length at the requested sample rate
782
+ x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=expected_secs)
783
+
784
+ audio_b64, total_samples, channels = wav_bytes_base64(x, target_sr)
785
+
786
+ metadata = {
787
+ "bpm": int(round(bpm)),
788
+ "bars": int(bars),
789
+ "beats_per_bar": int(beats_per_bar),
790
+ "styles": [s.strip() for s in (styles.split(",") if styles else []) if s.strip()],
791
+ "style_weights": [float(y) for y in style_weights.split(",")] if style_weights else None,
792
+ "sample_rate": int(target_sr),
793
+ "channels": int(channels),
794
+ "crossfade_seconds": mrt.config.crossfade_length,
795
+ "seconds_per_bar": seconds_per_bar,
796
+ "loop_duration_seconds": total_samples / float(target_sr),
797
+ "guidance_weight": guidance_weight,
798
+ "temperature": temperature,
799
+ "topk": topk,
800
+ }
801
+ return {"audio_base64": audio_b64, "metadata": metadata}
802
+
803
+
804
  # ----------------------------
805
  # the 'keep jamming' button
806
  # ----------------------------