Spaces:
Running
Running
""" | |
One-shot music generation functions for MagentaRT. | |
This module contains the core generation functions extracted from the main app | |
that can be used independently for single-shot music generation tasks. | |
""" | |
import math | |
import numpy as np | |
from magenta_rt import audio as au | |
from utils import ( | |
match_loudness_to_reference, | |
stitch_generated, | |
hard_trim_seconds, | |
apply_micro_fades, | |
make_bar_aligned_context, | |
take_bar_aligned_tail | |
) | |
def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
):
    """
    Generate a continuation of an input loop using MagentaRT.

    The bar-aligned tail of the input loop seeds the model's context tokens.
    Its style embedding is blended with any extra text prompts. Enough chunks
    are generated to cover ``bars`` (plus any dropped intro bars). The result
    is stitched, trimmed to the exact musical length, and loudness-matched
    to the input loop.

    Args:
        mrt: MagentaRT instance.
        input_wav_path: Path to input audio file.
        bpm: Beats per minute; must be > 0.
        extra_styles: Additional text style prompts (optional). Entries that
            are None, empty, or whitespace-only are skipped.
        style_weights: Weights for ``extra_styles``, matched by position
            (``style_weights[i]`` weights ``extra_styles[i]``, even when
            earlier entries were skipped). Missing entries default to 1.0.
        bars: Number of bars to return; must be > 0.
        beats_per_bar: Beats per bar (typically 4); must be > 0.
        loop_weight: Weight for the input loop's style embedding.
        loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none").
        loudness_headroom_db: Headroom in dB for peak limiting.
        intro_bars_to_drop: Number of extra intro bars to generate and then
            discard. Negative values are treated as 0; values above ``bars``
            are clamped to ``bars``.

    Returns:
        Tuple of (au.Waveform output, dict loudness_stats).

    Raises:
        ValueError: If ``bpm``, ``bars`` or ``beats_per_bar`` is not positive.
    """
    # Validate the musical parameters before any audio I/O. A zero or negative
    # value would otherwise surface later as ZeroDivisionError, or as a
    # non-positive chunk count (nothing generated).
    if not (bpm > 0):
        raise ValueError(f"bpm must be positive, got {bpm!r}")
    if not (bars > 0):
        raise ValueError(f"bars must be positive, got {bars!r}")
    if not (beats_per_bar > 0):
        raise ValueError(f"beats_per_bar must be positive, got {beats_per_bar!r}")

    # Load the loop at the model's sample rate, as stereo.
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()

    # Take the bar-aligned tail of the loop, sized by the model's context
    # length in seconds, and use it as the conditioning audio.
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

    # Encode to codec tokens and keep the first `decoder_codec_rvq_depth`
    # entries of axis 1 (presumably the RVQ levels the decoder uses).
    tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    # Bar-aligned token window for the model context.
    context_tokens = make_bar_aligned_context(
        tokens, bpm=bpm, fps=codec_fps,
        ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar,
    )
    state = mrt.init_state()
    state.context_tokens = context_tokens

    # Style: embedding of the context tail plus any non-blank text prompts.
    # Using the tail (not the whole loop) biases the style toward recent material.
    loop_embed = mrt.embed_style(loop_for_context)
    embeds, weights = [loop_embed], [float(loop_weight)]
    for i, style in enumerate(extra_styles or ()):
        prompt = (style or "").strip()
        if not prompt:
            continue
        embeds.append(mrt.embed_style(prompt))
        # Weights follow extra_styles by position, so skipped entries keep their slot.
        w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
        weights.append(float(w))
    # Normalize so the blend is a weighted average. If the weights sum to 0,
    # divide by 1.0 instead of raising.
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    combined_style = np.sum(
        [w * e for w, e in zip(weights, embeds)], axis=0
    ).astype(loop_embed.dtype)

    # --- Length math ---
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = min(max(0, int(intro_bars_to_drop)), bars)
    drop_secs = drop_bars * seconds_per_bar
    gen_total_secs = total_secs + drop_secs  # generate the dropped intro too

    # Cover gen_total_secs, plus one extra chunk of headroom; the excess is trimmed below.
    chunk_secs = (
        mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
    )
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)
        chunks.append(wav)

    # Stitch chunks into continuous audio, then trim to (bars + dropped bars).
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Drop the intro bars.
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    # Exact-length trim to the requested bars, then polish (after the drop).
    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    # Return value is unused, so this presumably fades `out` in place — TODO confirm.
    apply_micro_fades(out, 5)

    # Loudness-match to the input loop after the drop, so bar 1 sits right.
    out, loud_stats = match_loudness_to_reference(
        ref=loop, target=out,
        method=loudness_mode, headroom_db=loudness_headroom_db,
    )
    return out, loud_stats
def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only, bar-aligned generation using a silent context (no input audio).

    Args:
        mrt: MagentaRT instance.
        bpm: Beats per minute; must be > 0.
        bars: Number of bars to return; must be > 0.
        beats_per_bar: Beats per bar (typically 4); must be > 0.
        styles: Comma-separated text prompts. Blank entries are ignored; if
            none remain, "warmup" is used.
        style_weights: Comma-separated weights. The i-th weight applies to the
            i-th non-blank prompt. Missing or blank entries default to 1.0.
        intro_bars_to_drop: Number of extra intro bars to generate and then
            discard. Negative values are treated as 0; values above ``bars``
            are clamped to ``bars``.

    Returns:
        (au.Waveform out, None) — loudness stats are not computed because
        there is no reference audio.

    Raises:
        ValueError: If ``bpm``, ``bars`` or ``beats_per_bar`` is not positive,
            or a non-blank entry of ``style_weights`` is not a number.
    """
    # Validate before touching the model: a zero or negative value would
    # otherwise surface as ZeroDivisionError or a non-positive chunk count.
    if not (bpm > 0):
        raise ValueError(f"bpm must be positive, got {bpm!r}")
    if not (bars > 0):
        raise ValueError(f"bars must be positive, got {bars!r}")
    if not (beats_per_bar > 0):
        raise ValueError(f"beats_per_bar must be positive, got {beats_per_bar!r}")

    # ---- Silent stereo context, one model context window long, tokenized ----
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)
    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
    state = mrt.init_state()
    state.context_tokens = tokens

    # ---- Style vector (text prompts only, normalized weights) ----
    prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    if not prompts:
        prompts = ["warmup"]
    # A blank weight entry (e.g. "1.0,,0.5" or a trailing comma) means
    # "use the default"; previously float("") raised ValueError here.
    sw = (
        [float(x) if x.strip() else 1.0 for x in style_weights.split(",")]
        if style_weights
        else []
    )
    embeds = [mrt.embed_style(p) for p in prompts]
    weights = [sw[i] if i < len(sw) else 1.0 for i in range(len(prompts))]
    # Normalize so the blend is a weighted average. If the weights sum to 0,
    # divide by 1.0 instead of raising.
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)

    # ---- Target length math ----
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = min(max(0, int(intro_bars_to_drop)), bars)
    drop_secs = drop_bars * seconds_per_bar
    gen_total_secs = total_secs + drop_secs

    # Chunk length from the model config. Generate enough chunks to cover the
    # total, plus one extra chunk of crossfade headroom (trimmed below).
    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style_vec)
        chunks.append(wav)

    # Stitch, trim to (bars + dropped bars), drop the intro, trim to exact length.
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    # Return value is unused, so this presumably fades `out` in place — TODO confirm.
    apply_micro_fades(out, 5)
    return out, None  # loudness stats not applicable (no reference)