"""
One-shot music generation functions for MagentaRT.

This module contains the core generation functions extracted from the main app
that can be used independently for single-shot music generation tasks.
"""

import math

import numpy as np

from magenta_rt import audio as au
from utils import (
    match_loudness_to_reference,
    stitch_generated,
    hard_trim_seconds,
    apply_micro_fades,
    make_bar_aligned_context,
    take_bar_aligned_tail,
)


def generate_loop_continuation_with_mrt(
    mrt,
    input_wav_path: str,
    bpm: float,
    extra_styles=None,
    style_weights=None,
    bars: int = 8,
    beats_per_bar: int = 4,
    loop_weight: float = 1.0,
    loudness_mode: str = "auto",
    loudness_headroom_db: float = 1.0,
    intro_bars_to_drop: int = 0,
):
    """
    Generate a continuation of an input loop using MagentaRT.

    Args:
        mrt: MagentaRT instance
        input_wav_path: Path to input audio file
        bpm: Beats per minute
        extra_styles: List of additional text style prompts (optional)
        style_weights: List of weights for style prompts (optional)
        bars: Number of bars to generate
        beats_per_bar: Beats per bar (typically 4)
        loop_weight: Weight for the input loop's style embedding
        loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
        loudness_headroom_db: Headroom in dB for peak limiting
        intro_bars_to_drop: Number of intro bars to generate, then drop

    Returns:
        Tuple of (au.Waveform output, dict loudness_stats)
    """
    # Load the input loop and conform it to the model's sample rate and channel count.
    loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()

    # Use a bar-aligned tail of the loop that fits the model's context window.
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)

    # Encode the context audio to codec tokens, keeping only the RVQ depth
    # the decoder actually consumes.
    tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    # Seed the generation state with a bar-aligned token context.
    context_tokens = make_bar_aligned_context(
        tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
        ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
    )
    state = mrt.init_state()
    state.context_tokens = context_tokens

    # Blend the loop's style embedding with any extra text-prompt styles,
    # normalizing the weights so they sum to 1.
    loop_embed = mrt.embed_style(loop_for_context)
    embeds, weights = [loop_embed], [float(loop_weight)]
    if extra_styles:
        for i, s in enumerate(extra_styles):
            if s.strip():
                embeds.append(mrt.embed_style(s.strip()))
                w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
                weights.append(float(w))
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
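    # e.g. loop_weight=1.0 plus one extra style at weight 1.0 blends them 50/50.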

    # Timing: the bars requested, plus optional throwaway intro bars.
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs
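    # e.g. at 120 BPM in 4/4, seconds_per_bar = 2.0 s, so bars=8 gives
    # total_secs = 16 s, and one dropped intro bar raises gen_total_secs to 18 s.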

    # Generate one extra chunk so the crossfade stitching has headroom.
    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
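    # e.g. with 2 s chunks and gen_total_secs = 18 s: steps = ceil(18 / 2) + 1 = 10.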

    # Stream chunks from the model, threading the state through each call.
    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=combined_style)
        chunks.append(wav)

    # Crossfade the chunks together, then trim to the generated length.
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Drop the throwaway intro bars, if any, then trim to the target length.
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    out = hard_trim_seconds(stitched, total_secs)

    # Normalize with headroom and de-click the edges with short fades.
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)

    # Match the output loudness to the input loop.
    out, loud_stats = match_loudness_to_reference(
        ref=loop, target=out,
        method=loudness_mode, headroom_db=loudness_headroom_db
    )

    return out, loud_stats
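
# Example sketch for generate_loop_continuation_with_mrt. The `mrt` instance is
# assumed to come from the magenta_rt package (see the __main__ sketch at the
# bottom of this file); "loop.wav" is a placeholder path:
#
#   out, stats = generate_loop_continuation_with_mrt(
#       mrt, "loop.wav", bpm=120.0,
#       extra_styles=["lofi hip hop"], style_weights=[0.5],
#       bars=8, intro_bars_to_drop=1,
#   )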


def generate_style_only_with_mrt(
    mrt,
    bpm: float,
    bars: int = 8,
    beats_per_bar: int = 4,
    styles: str = "warmup",
    style_weights: str = "",
    intro_bars_to_drop: int = 0,
):
    """
    Style-only, bar-aligned generation using a silent context (no input audio).

    Returns:
        Tuple of (au.Waveform out, dict loud_stats_or_None)
    """

    # Size the context window in seconds from the codec frame rate.
    codec_fps = float(mrt.codec.frame_rate)
    ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
    sr = int(mrt.sample_rate)

    # Encode a silent stereo context and seed the generation state with it.
    silent = au.Waveform(np.zeros((int(round(ctx_seconds * sr)), 2), np.float32), sr)
    tokens_full = mrt.codec.encode(silent).astype(np.int32)
    tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]

    state = mrt.init_state()
    state.context_tokens = tokens

    # Parse comma-separated style prompts and weights, then blend the style
    # embeddings with weights normalized to sum to 1.
    prompts = [s.strip() for s in (styles.split(",") if styles else []) if s.strip()]
    if not prompts:
        prompts = ["warmup"]
    sw = [float(x) for x in style_weights.split(",")] if style_weights else []
    embeds, weights = [], []
    for i, p in enumerate(prompts):
        embeds.append(mrt.embed_style(p))
        weights.append(sw[i] if i < len(sw) else 1.0)
    wsum = float(sum(weights)) or 1.0
    weights = [w / wsum for w in weights]
    style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
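    # e.g. styles="jazz, techno" with style_weights="2, 1" yields weights [2/3, 1/3].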

    # Timing math mirrors generate_loop_continuation_with_mrt above.
    seconds_per_bar = beats_per_bar * (60.0 / bpm)
    total_secs = bars * seconds_per_bar
    drop_bars = max(0, int(intro_bars_to_drop))
    drop_secs = min(drop_bars, bars) * seconds_per_bar
    gen_total_secs = total_secs + drop_secs

    # One extra chunk gives the crossfade stitching headroom.
    chunk_secs = (mrt.config.chunk_length_frames * mrt.config.frame_length_samples) / float(mrt.sample_rate)
    steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1

    chunks = []
    for _ in range(steps):
        wav, state = mrt.generate_chunk(state=state, style=style_vec)
        chunks.append(wav)

    # Crossfade the chunks together, then trim to the generated length.
    stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
    stitched = hard_trim_seconds(stitched, gen_total_secs)

    # Drop the throwaway intro bars, if any, then trim to the target length.
    if drop_secs > 0:
        n_drop = int(round(drop_secs * stitched.sample_rate))
        stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)

    out = hard_trim_seconds(stitched, total_secs)
    out = out.peak_normalize(0.95)
    apply_micro_fades(out, 5)

    return out, None
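

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the extracted API. Assumptions:
    # magenta_rt exposes system.MagentaRT() as in its published examples, and
    # soundfile is available for writing the raw samples; adjust to your setup.
    import soundfile as sf
    from magenta_rt import system

    mrt = system.MagentaRT()
    wave, _ = generate_style_only_with_mrt(
        mrt, bpm=120.0, bars=4, styles="warmup, synth pad", style_weights="1.0, 0.5"
    )
    sf.write("style_only.wav", wave.samples, int(wave.sample_rate))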