thecollabagepatch committed on
Commit
842a99f
·
1 Parent(s): dcfd5bb

extract one-shot generation

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -0
  2. app.py +165 -163
  3. one_shot_generation.py +196 -0
Dockerfile CHANGED
@@ -142,6 +142,7 @@ COPY --chown=appuser:appuser app.py /home/appuser/app/app.py
142
  COPY --chown=appuser:appuser utils.py /home/appuser/app/utils.py
143
  COPY --chown=appuser:appuser jam_worker.py /home/appuser/app/jam_worker.py
144
 
 
145
  COPY --chown=appuser:appuser documentation.html /home/appuser/app/documentation.html
146
 
147
  USER appuser
 
142
  COPY --chown=appuser:appuser utils.py /home/appuser/app/utils.py
143
  COPY --chown=appuser:appuser jam_worker.py /home/appuser/app/jam_worker.py
144
 
145
# app.py imports one_shot_generation, so the module must be copied into the image.
COPY --chown=appuser:appuser one_shot_generation.py /home/appuser/app/one_shot_generation.py
146
  COPY --chown=appuser:appuser documentation.html /home/appuser/app/documentation.html
147
 
148
  USER appuser
app.py CHANGED
@@ -46,6 +46,8 @@ from utils import (
46
  )
47
 
48
  from jam_worker import JamWorker, JamParams, JamChunk
 
 
49
  import uuid, threading
50
 
51
  import logging
@@ -560,169 +562,169 @@ try:
560
  except Exception:
561
  _HAS_LOUDNORM = False
562
 
563
- # ----------------------------
564
- # Main generation (single combined style vector)
565
- # ----------------------------
566
- def generate_loop_continuation_with_mrt(
567
- mrt,
568
- input_wav_path: str,
569
- bpm: float,
570
- extra_styles=None,
571
- style_weights=None,
572
- bars: int = 8,
573
- beats_per_bar: int = 4,
574
- loop_weight: float = 1.0,
575
- loudness_mode: str = "auto",
576
- loudness_headroom_db: float = 1.0,
577
- intro_bars_to_drop: int = 0, # <— NEW
578
- ):
579
- # Load & prep (unchanged)
580
- loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
581
-
582
- # Use tail for context (your recent change)
583
- codec_fps = float(mrt.codec.frame_rate)
584
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
585
- loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
586
-
587
- tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
588
- tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
589
-
590
- # Bar-aligned token window (unchanged)
591
- context_tokens = make_bar_aligned_context(
592
- tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
593
- ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
594
- )
595
- state = mrt.init_state()
596
- state.context_tokens = context_tokens
597
-
598
- # STYLE embed (optional: switch to loop_for_context if you want stronger “recent” bias)
599
- loop_embed = mrt.embed_style(loop_for_context)
600
- embeds, weights = [loop_embed], [float(loop_weight)]
601
- if extra_styles:
602
- for i, s in enumerate(extra_styles):
603
- if s.strip():
604
- embeds.append(mrt.embed_style(s.strip()))
605
- w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
606
- weights.append(float(w))
607
- wsum = float(sum(weights)) or 1.0
608
- weights = [w / wsum for w in weights]
609
- combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
610
-
611
- # --- Length math ---
612
- seconds_per_bar = beats_per_bar * (60.0 / bpm)
613
- total_secs = bars * seconds_per_bar
614
- drop_bars = max(0, int(intro_bars_to_drop))
615
- drop_secs = min(drop_bars, bars) * seconds_per_bar # clamp to <= bars
616
- gen_total_secs = total_secs + drop_secs # generate extra
617
-
618
- # Chunk scheduling to cover gen_total_secs
619
- chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
620
- steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1 # pad then trim
621
-
622
- # Generate
623
- chunks = []
624
- for _ in range(steps):
625
- wav, state = mrt.generate_chunk(state=state, style=combined_style)
626
- chunks.append(wav)
627
-
628
- # Stitch continuous audio
629
- stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
630
-
631
- # Trim to generated length (bars + dropped bars)
632
- stitched = hard_trim_seconds(stitched, gen_total_secs)
633
-
634
- # 👉 Drop the intro bars
635
- if drop_secs > 0:
636
- n_drop = int(round(drop_secs * stitched.sample_rate))
637
- stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
638
-
639
- # Final exact-length trim to requested bars
640
- out = hard_trim_seconds(stitched, total_secs)
641
-
642
- # Final polish AFTER drop
643
- out = out.peak_normalize(0.95)
644
- apply_micro_fades(out, 5)
645
-
646
- # Loudness match to input (after drop) so bar 1 sits right
647
- out, loud_stats = match_loudness_to_reference(
648
- ref=loop, target=out,
649
- method=loudness_mode, headroom_db=loudness_headroom_db
650
- )
651
-
652
- return out, loud_stats
653
-
654
- # untested.
655
- # not sure how it will retain the input bpm. we may want to use a metronome instead of silence. i think google might do that.
656
- # does a generation with silent context rather than a combined loop
657
- def generate_style_only_with_mrt(
658
- mrt,
659
- bpm: float,
660
- bars: int = 8,
661
- beats_per_bar: int = 4,
662
- styles: str = "warmup",
663
- style_weights: str = "",
664
- intro_bars_to_drop: int = 0,
665
- ):
666
- """
667
- Style-only, bar-aligned generation using a silent context (no input audio).
668
- Returns: (au.Waveform out, dict loud_stats_or_None)
669
- """
670
- # ---- Build a 10s silent context, tokenized for the model ----
671
- codec_fps = float(mrt.codec.frame_rate)
672
- ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
673
- sr = int(mrt.sample_rate)
674
-
675
- silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
676
- tokens_full = mrt.codec.encode(silent).astype(np.int32)
677
- tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
678
-
679
- state = mrt.init_state()
680
- state.context_tokens = tokens
681
-
682
- # ---- Style vector (text prompts only, normalized weights) ----
683
- prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
684
- if not prompts:
685
- prompts = ["warmup"]
686
- sw = [float(x) for x in style_weights.split(",")] if style_weights else []
687
- embeds, weights = [], []
688
- for i, p in enumerate(prompts):
689
- embeds.append(mrt.embed_style(p))
690
- weights.append(sw[i] if i < len(sw) else 1.0)
691
- wsum = float(sum(weights)) or 1.0
692
- weights = [w / wsum for w in weights]
693
- style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
694
-
695
- # ---- Target length math ----
696
- seconds_per_bar = beats_per_bar * (60.0 / bpm)
697
- total_secs = bars * seconds_per_bar
698
- drop_bars = max(0, int(intro_bars_to_drop))
699
- drop_secs = min(drop_bars, bars) * seconds_per_bar
700
- gen_total_secs = total_secs + drop_secs
701
-
702
- # ~2.0s chunk length from model config
703
- chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
704
-
705
- # Generate enough chunks to cover total, plus a pad chunk for crossfade headroom
706
- steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
707
-
708
- chunks = []
709
- for _ in range(steps):
710
- wav, state = mrt.generate_chunk(state=state, style=style_vec)
711
- chunks.append(wav)
712
-
713
- # Stitch & trim to exact musical length
714
- stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
715
- stitched = hard_trim_seconds(stitched, gen_total_secs)
716
-
717
- if drop_secs > 0:
718
- n_drop = int(round(drop_secs * stitched.sample_rate))
719
- stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
720
-
721
- out = hard_trim_seconds(stitched, total_secs)
722
- out = out.peak_normalize(0.95)
723
- apply_micro_fades(out, 5)
724
-
725
- return out, None # loudness stats not applicable (no reference)
726
 
727
  def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
728
  extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
 
46
  )
47
 
48
  from jam_worker import JamWorker, JamParams, JamChunk
49
+ from one_shot_generation import generate_loop_continuation_with_mrt, generate_style_only_with_mrt
50
+
51
  import uuid, threading
52
 
53
  import logging
 
562
  except Exception:
563
  _HAS_LOUDNORM = False
564
 
565
+ # # ----------------------------
566
+ # # Main generation (single combined style vector)
567
+ # # ----------------------------
568
+ # def generate_loop_continuation_with_mrt(
569
+ # mrt,
570
+ # input_wav_path: str,
571
+ # bpm: float,
572
+ # extra_styles=None,
573
+ # style_weights=None,
574
+ # bars: int = 8,
575
+ # beats_per_bar: int = 4,
576
+ # loop_weight: float = 1.0,
577
+ # loudness_mode: str = "auto",
578
+ # loudness_headroom_db: float = 1.0,
579
+ # intro_bars_to_drop: int = 0, # <— NEW
580
+ # ):
581
+ # # Load & prep (unchanged)
582
+ # loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
583
+
584
+ # # Use tail for context (your recent change)
585
+ # codec_fps = float(mrt.codec.frame_rate)
586
+ # ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
587
+ # loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
588
+
589
+ # tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
590
+ # tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
591
+
592
+ # # Bar-aligned token window (unchanged)
593
+ # context_tokens = make_bar_aligned_context(
594
+ # tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
595
+ # ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
596
+ # )
597
+ # state = mrt.init_state()
598
+ # state.context_tokens = context_tokens
599
+
600
+ # # STYLE embed (optional: switch to loop_for_context if you want stronger “recent” bias)
601
+ # loop_embed = mrt.embed_style(loop_for_context)
602
+ # embeds, weights = [loop_embed], [float(loop_weight)]
603
+ # if extra_styles:
604
+ # for i, s in enumerate(extra_styles):
605
+ # if s.strip():
606
+ # embeds.append(mrt.embed_style(s.strip()))
607
+ # w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
608
+ # weights.append(float(w))
609
+ # wsum = float(sum(weights)) or 1.0
610
+ # weights = [w / wsum for w in weights]
611
+ # combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
612
+
613
+ # # --- Length math ---
614
+ # seconds_per_bar = beats_per_bar * (60.0 / bpm)
615
+ # total_secs = bars * seconds_per_bar
616
+ # drop_bars = max(0, int(intro_bars_to_drop))
617
+ # drop_secs = min(drop_bars, bars) * seconds_per_bar # clamp to <= bars
618
+ # gen_total_secs = total_secs + drop_secs # generate extra
619
+
620
+ # # Chunk scheduling to cover gen_total_secs
621
+ # chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
622
+ # steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1 # pad then trim
623
+
624
+ # # Generate
625
+ # chunks = []
626
+ # for _ in range(steps):
627
+ # wav, state = mrt.generate_chunk(state=state, style=combined_style)
628
+ # chunks.append(wav)
629
+
630
+ # # Stitch continuous audio
631
+ # stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
632
+
633
+ # # Trim to generated length (bars + dropped bars)
634
+ # stitched = hard_trim_seconds(stitched, gen_total_secs)
635
+
636
+ # # 👉 Drop the intro bars
637
+ # if drop_secs > 0:
638
+ # n_drop = int(round(drop_secs * stitched.sample_rate))
639
+ # stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
640
+
641
+ # # Final exact-length trim to requested bars
642
+ # out = hard_trim_seconds(stitched, total_secs)
643
+
644
+ # # Final polish AFTER drop
645
+ # out = out.peak_normalize(0.95)
646
+ # apply_micro_fades(out, 5)
647
+
648
+ # # Loudness match to input (after drop) so bar 1 sits right
649
+ # out, loud_stats = match_loudness_to_reference(
650
+ # ref=loop, target=out,
651
+ # method=loudness_mode, headroom_db=loudness_headroom_db
652
+ # )
653
+
654
+ # return out, loud_stats
655
+
656
+ # # untested.
657
+ # # not sure how it will retain the input bpm. we may want to use a metronome instead of silence. i think google might do that.
658
+ # # does a generation with silent context rather than a combined loop
659
+ # def generate_style_only_with_mrt(
660
+ # mrt,
661
+ # bpm: float,
662
+ # bars: int = 8,
663
+ # beats_per_bar: int = 4,
664
+ # styles: str = "warmup",
665
+ # style_weights: str = "",
666
+ # intro_bars_to_drop: int = 0,
667
+ # ):
668
+ # """
669
+ # Style-only, bar-aligned generation using a silent context (no input audio).
670
+ # Returns: (au.Waveform out, dict loud_stats_or_None)
671
+ # """
672
+ # # ---- Build a 10s silent context, tokenized for the model ----
673
+ # codec_fps = float(mrt.codec.frame_rate)
674
+ # ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
675
+ # sr = int(mrt.sample_rate)
676
+
677
+ # silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
678
+ # tokens_full = mrt.codec.encode(silent).astype(np.int32)
679
+ # tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
680
+
681
+ # state = mrt.init_state()
682
+ # state.context_tokens = tokens
683
+
684
+ # # ---- Style vector (text prompts only, normalized weights) ----
685
+ # prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
686
+ # if not prompts:
687
+ # prompts = ["warmup"]
688
+ # sw = [float(x) for x in style_weights.split(",")] if style_weights else []
689
+ # embeds, weights = [], []
690
+ # for i, p in enumerate(prompts):
691
+ # embeds.append(mrt.embed_style(p))
692
+ # weights.append(sw[i] if i < len(sw) else 1.0)
693
+ # wsum = float(sum(weights)) or 1.0
694
+ # weights = [w / wsum for w in weights]
695
+ # style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
696
+
697
+ # # ---- Target length math ----
698
+ # seconds_per_bar = beats_per_bar * (60.0 / bpm)
699
+ # total_secs = bars * seconds_per_bar
700
+ # drop_bars = max(0, int(intro_bars_to_drop))
701
+ # drop_secs = min(drop_bars, bars) * seconds_per_bar
702
+ # gen_total_secs = total_secs + drop_secs
703
+
704
+ # # ~2.0s chunk length from model config
705
+ # chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
706
+
707
+ # # Generate enough chunks to cover total, plus a pad chunk for crossfade headroom
708
+ # steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
709
+
710
+ # chunks = []
711
+ # for _ in range(steps):
712
+ # wav, state = mrt.generate_chunk(state=state, style=style_vec)
713
+ # chunks.append(wav)
714
+
715
+ # # Stitch & trim to exact musical length
716
+ # stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
717
+ # stitched = hard_trim_seconds(stitched, gen_total_secs)
718
+
719
+ # if drop_secs > 0:
720
+ # n_drop = int(round(drop_secs * stitched.sample_rate))
721
+ # stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
722
+
723
+ # out = hard_trim_seconds(stitched, total_secs)
724
+ # out = out.peak_normalize(0.95)
725
+ # apply_micro_fades(out, 5)
726
+
727
+ # return out, None # loudness stats not applicable (no reference)
728
 
729
  def _combine_styles(mrt, styles_str: str = "", weights_str: str = ""):
730
  extra = [s.strip() for s in (styles_str or "").split(",") if s.strip()]
one_shot_generation.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ One-shot music generation functions for MagentaRT.
3
+
4
+ This module contains the core generation functions extracted from the main app
5
+ that can be used independently for single-shot music generation tasks.
6
+ """
7
+ import math
8
+ import numpy as np
9
+ from magenta_rt import audio as au
10
+ from utils import (
11
+ match_loudness_to_reference,
12
+ stitch_generated,
13
+ hard_trim_seconds,
14
+ apply_micro_fades,
15
+ make_bar_aligned_context,
16
+ take_bar_aligned_tail
17
+ )
18
+
19
+
20
def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
):
    """
    Continue an input loop for a fixed number of bars.

    The bar-aligned tail of the input audio is encoded as the model context,
    its style embedding is blended with any text prompts, and chunks are
    generated until the requested length (plus any intro bars that will be
    discarded) is covered.

    Args:
        mrt: Model object providing sample_rate, codec, config, init_state,
            embed_style and generate_chunk.
        input_wav_path: Path of the audio file to continue.
        bpm: Tempo used for all bar-length arithmetic.
        extra_styles: Optional sequence of text prompts; blank entries are skipped.
        style_weights: Optional weights looked up by each prompt's position in
            extra_styles; a missing weight defaults to 1.0.
        bars: Number of bars in the returned audio.
        beats_per_bar: Beats in one bar.
        loop_weight: Weight of the input loop's own style embedding.
        loudness_mode: Forwarded as `method` to match_loudness_to_reference.
        loudness_headroom_db: Forwarded as `headroom_db` to match_loudness_to_reference.
        intro_bars_to_drop: Extra leading bars to generate and then discard
            (negative values count as 0; clamped to at most `bars`).

    Returns:
        Tuple (output waveform, loudness stats) as produced by
        match_loudness_to_reference.
    """
    # Source audio at the model's sample rate, in stereo.
    source_loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()

    # Only the bar-aligned tail, sized to the model's context window, is used as context.
    frame_rate = float(mrt.codec.frame_rate)
    context_secs = float(mrt.config.context_length_frames) / frame_rate
    context_audio = take_bar_aligned_tail(source_loop, bpm, beats_per_bar, context_secs)

    # Encode the context and keep only the RVQ depth the decoder consumes.
    all_tokens = mrt.codec.encode(context_audio).astype(np.int32)
    decoder_tokens = all_tokens[:, :mrt.config.decoder_codec_rvq_depth]

    aligned_tokens = make_bar_aligned_context(
        decoder_tokens,
        bpm=bpm,
        fps=frame_rate,
        ctx_frames=mrt.config.context_length_frames,
        beats_per_bar=beats_per_bar,
    )
    state = mrt.init_state()
    state.context_tokens = aligned_tokens

    # Style blend: the loop's own embedding first, then each non-blank text prompt.
    loop_embed = mrt.embed_style(context_audio)
    embeddings = [loop_embed]
    mix = [float(loop_weight)]
    for idx, prompt in enumerate(extra_styles or ()):
        text = prompt.strip()
        if not text:
            continue
        embeddings.append(mrt.embed_style(text))
        if style_weights and idx < len(style_weights):
            mix.append(float(style_weights[idx]))
        else:
            mix.append(1.0)

    # Normalise the weights; an all-zero sum falls back to a divisor of 1.0.
    mix_total = float(sum(mix)) or 1.0
    mix = [m / mix_total for m in mix]
    combined_style = np.sum(
        [m * emb for m, emb in zip(mix, embeddings)], axis=0
    ).astype(loop_embed.dtype)

    # Durations: requested output, discarded intro, and their sum to generate.
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs

    # One chunk more than strictly needed; the surplus is trimmed after stitching.
    chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
    n_chunks = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    generated = []
    for _ in range(n_chunks):
        piece, state = mrt.generate_chunk(state=state, style=combined_style)
        generated.append(piece)

    # Join the chunks, then cut to the generated length (bars + dropped bars).
    stitched = stitch_generated(generated, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Discard the intro bars, if any were requested.
    if drop_secs > 0:
        skip = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[skip:], stitched.sample_rate)

    # Exact-length trim and final polish happen after the intro drop.
    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)

    # Match loudness against the full input loop.
    out, loud_stats = match_loudness_to_reference(
        ref=source_loop,
        target=out,
        method=loudness_mode,
        headroom_db=loudness_headroom_db,
    )

    return out, loud_stats
126
+
127
+
128
def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only generation using a silent context (no input audio).

    Args:
        mrt: Model object providing sample_rate, codec, config, init_state,
            embed_style and generate_chunk.
        bpm: Tempo used for the bar-length arithmetic.
        bars: Number of bars in the returned audio.
        beats_per_bar: Beats in one bar.
        styles: Comma-separated text prompts; blank entries are ignored and
            an empty result falls back to "warmup".
        style_weights: Comma-separated weights matched by position to the
            non-blank prompts. Blank or missing entries default to 1.0.
        intro_bars_to_drop: Extra leading bars to generate and then discard
            (negative values count as 0; clamped to at most `bars`).

    Returns:
        Tuple (au.Waveform out, None); there is no reference audio, so no
        loudness stats are produced.

    Raises:
        ValueError: If a non-blank weight entry cannot be parsed as a float.
    """
    # ---- Build a silent context spanning the model's context window ----
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)

    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    # Keep only the RVQ depth the decoder consumes.
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    state = mrt.init_state()
    state.context_tokens = tokens

    # ---- Style vector (text prompts only, normalized weights) ----
    prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    if not prompts:
        prompts = ["warmup"]

    # A blank token (e.g. the trailing one in "1.0," or the middle one in
    # "1,,2") would make float() raise ValueError, so it is treated as the
    # default weight instead. Non-blank garbage still raises ValueError.
    weight_tokens = style_weights.split(",") if style_weights else []
    sw = [float(tok) if tok.strip() else 1.0 for tok in weight_tokens]

    embeds, weights = [], []
    for i, p in enumerate(prompts):
        embeds.append(mrt.embed_style(p))
        weights.append(sw[i] if i < len(sw) else 1.0)
    # An all-zero weight sum falls back to a divisor of 1.0.
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)

    # ---- Target length math ----
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs

    # Chunk length in seconds, derived from the model config.
    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)

    # Generate enough chunks to cover the total, plus one pad chunk that is
    # trimmed away after stitching.
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style_vec)
        chunks.append(wav)

    # Stitch & trim to the generated length (bars + dropped bars).
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Discard the intro bars, if any were requested.
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    # Exact-length trim and final polish happen after the intro drop.
    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)

    return out, None  # loudness stats not applicable (no reference)