"""
One-shot music generation functions for MagentaRT.

This module contains the core generation functions extracted from the main app
that can be used independently for single-shot music generation tasks.
"""

import math

import numpy as np
from magenta_rt import audio as au

from utils import (
    match_loudness_to_reference,
    stitch_generated,
    hard_trim_seconds,
    apply_micro_fades,
    make_bar_aligned_context,
    take_bar_aligned_tail,
)


def _mix_style_embeddings(embeds, weights, dtype):
    """Blend style embeddings using weights normalized to sum to 1.

    Args:
        embeds: Style embeddings (as returned by ``mrt.embed_style``), all the same shape.
        weights: One raw weight per embedding.
        dtype: dtype of the returned blended embedding.

    Returns:
        ``sum((w_i / sum(w)) * e_i)`` cast to ``dtype``. If the weights sum to zero the
        divisor falls back to 1.0 to avoid a division by zero.
    """
    total = float(sum(weights)) or 1.0
    scaled = [(float(w) / total) * e for w, e in zip(weights, embeds)]
    return np.sum(scaled, axis=0).astype(dtype)


def _render_bars(mrt, state, style, bpm, bars, beats_per_bar, intro_bars_to_drop):
    """Generate, stitch and trim MagentaRT output to exactly ``bars`` bars at ``bpm``.

    ``intro_bars_to_drop`` extra bars (clamped to at most ``bars``) are generated first and
    then cut from the start of the audio. The result is peak-normalized to 0.95 and given
    5 ms micro fades via ``apply_micro_fades``.

    Args:
        mrt: MagentaRT instance.
        state: Initial generation state (its ``context_tokens`` already set by the caller).
        style: Style embedding passed to every ``mrt.generate_chunk`` call.
        bpm: Beats per minute (must be positive; validated by the public callers).
        bars: Number of bars to return.
        beats_per_bar: Beats per bar.
        intro_bars_to_drop: Number of leading bars to generate and then discard.

    Returns:
        Stereo ``au.Waveform`` hard-trimmed to ``bars * beats_per_bar * 60 / bpm`` seconds.
    """
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar  # never drop more than `bars` bars
    gen_total_secs = total_secs + drop_secs  # generate the intro on top of the target

    # Chunk length in seconds as configured by the model.
    chunk_secs = (
        mrt.config.chunk_length_frames * mrt.config.frame_length_samples
    ) / float(mrt.sample_rate)
    # Enough chunks to cover the target, plus one pad chunk of headroom for the crossfade.
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style)
        chunks.append(wav)

    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    # Trim to the generated length (requested bars + dropped intro bars).
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    # Final exact-length trim, then polish AFTER the drop so the fades land on the kept audio.
    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)  # return value unused; helper presumably edits `out` in place
    return out


def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
):
    """
    Generate a continuation of an input loop using MagentaRT.

    Args:
        mrt: MagentaRT instance
        input_wav_path: Path to input audio file
        bpm: Beats per minute (must be > 0)
        extra_styles: List of additional text style prompts (optional). Blank or ``None``
            entries are skipped.
        style_weights: List of weights for style prompts (optional). ``style_weights[i]``
            applies to ``extra_styles[i]``; missing weights default to 1.0.
        bars: Number of bars to generate
        beats_per_bar: Beats per bar (typically 4)
        loop_weight: Weight for the input loop's style embedding
        loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
        loudness_headroom_db: Headroom in dB for peak limiting
        intro_bars_to_drop: Number of intro bars to generate then drop

    Returns:
        Tuple of (au.Waveform output, dict loudness_stats)

    Raises:
        ValueError: If ``bpm`` is not a positive number.
    """
    if not bpm > 0:  # also rejects NaN
        raise ValueError(f"bpm must be positive, got {bpm!r}")

    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()

    # Only the bar-aligned tail of the loop (at most one model context window) is used as
    # context, so generation continues from the end of the loop.
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

    tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]  # keep decoder RVQ depth

    context_tokens = make_bar_aligned_context(
        tokens,
        bpm=bpm,
        fps=codec_fps,
        ctx_frames=mrt.config.context_length_frames,
        beats_per_bar=beats_per_bar,
    )

    state = mrt.init_state()
    state.context_tokens = context_tokens

    # Style: the context tail's own embedding plus any extra text prompts. Weights are
    # keyed by each prompt's original index so skipped blanks don't shift later weights.
    loop_embed = mrt.embed_style(loop_for_context)
    embeds, weights = [loop_embed], [float(loop_weight)]
    for i, raw in enumerate(extra_styles or []):
        prompt = (raw or "").strip()
        if not prompt:
            continue
        embeds.append(mrt.embed_style(prompt))
        w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
        weights.append(float(w))
    combined_style = _mix_style_embeddings(embeds, weights, loop_embed.dtype)

    out = _render_bars(
        mrt, state, combined_style, bpm, bars, beats_per_bar, intro_bars_to_drop
    )

    # Loudness-match to the input loop after the intro drop so bar 1 sits right.
    out, loud_stats = match_loudness_to_reference(
        ref=loop, target=out, method=loudness_mode, headroom_db=loudness_headroom_db
    )
    return out, loud_stats


def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only, bar-aligned generation using a silent context (no input audio).

    Args:
        mrt: MagentaRT instance
        bpm: Beats per minute (must be > 0)
        bars: Number of bars to generate
        beats_per_bar: Beats per bar (typically 4)
        styles: Comma-separated text style prompts. Blank entries are skipped; if none
            remain, "warmup" is used.
        style_weights: Comma-separated weights matched to ``styles`` by position (the
            i-th weight applies to the i-th comma field of ``styles``). Missing or blank
            entries default to 1.0.
        intro_bars_to_drop: Number of intro bars to generate then drop

    Returns:
        (au.Waveform out, None) -- no loudness stats, since there is no reference audio.

    Raises:
        ValueError: If ``bpm`` is not positive, or a non-blank weight is not a number.
    """
    if not bpm > 0:  # also rejects NaN
        raise ValueError(f"bpm must be positive, got {bpm!r}")

    # Silent stereo context spanning the model's full context window, tokenized.
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)
    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    state = mrt.init_state()
    state.context_tokens = tokens

    # Parse weights before embedding so malformed numbers fail fast. A blank field
    # ("1,,2" or a trailing comma) means "use the default weight".
    parsed_weights = (
        [float(x) if x.strip() else None for x in style_weights.split(",")]
        if style_weights
        else []
    )

    # Pair each prompt with the weight at the SAME comma position. Blank prompts are
    # skipped without shifting later weights onto the wrong prompt, matching how
    # generate_loop_continuation_with_mrt pairs extra_styles with style_weights.
    embeds, weights = [], []
    for i, raw in enumerate(styles.split(",") if styles else []):
        prompt = raw.strip()
        if not prompt:
            continue
        w = parsed_weights[i] if i < len(parsed_weights) else None
        embeds.append(mrt.embed_style(prompt))
        weights.append(1.0 if w is None else w)
    if not embeds:
        embeds, weights = [mrt.embed_style("warmup")], [1.0]
    style_vec = _mix_style_embeddings(embeds, weights, np.float32)

    out = _render_bars(mrt, state, style_vec, bpm, bars, beats_per_bar, intro_bars_to_drop)
    return out, None  # loudness stats not applicable (no reference)