# MeanAudio / app.py
# Commit 98c6962 by AndreasXi: "update meanaudio_l_full"
import warnings
import spaces
warnings.filterwarnings("ignore")
import logging
from argparse import ArgumentParser
from pathlib import Path
import torch
import torchaudio
import gradio as gr
from transformers import AutoModel
import laion_clap
from meanaudio.eval_utils import (
ModelConfig,
all_model_cfg,
generate_mf,
generate_fm,
setup_eval_logging,
)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import MeanAudio, get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import gc
import json
from datetime import datetime
from huggingface_hub import snapshot_download
import numpy as np
log = logging.getLogger()

# Prefer the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

setup_eval_logging()

# Generated clips are written here.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# How many copies of the prompt are sent to the generator per request.
NUM_SAMPLE = 1

# Create the RLHF feedback data directory; preferences are appended as JSONL.
FEEDBACK_DIR = Path("./rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"

# Process-wide caches so networks / feature utils are loaded once, not per request.
MODEL_CACHE = {}
FEATURE_UTILS_CACHE = {}
def fade_out(x, sr, fade_ms=50):
    """Apply a linear fade-out to the tail of an audio signal, in place.

    The fade is applied along the LAST axis. This handles both 1-D signals
    ``(samples,)`` and channel-first 2-D signals ``(channels, samples)``,
    which is the layout ``torchaudio.save`` expects. Measuring the signal
    with ``len(x)`` would return the channel count for a 2-D tensor, and the
    fade would be silently skipped.

    Args:
        x: NumPy array, torch tensor, or plain 1-D sequence. Modified in place.
        sr: Sampling rate in Hz.
        fade_ms: Fade length in milliseconds.

    Returns:
        ``x`` itself, with its last ``int(sr * fade_ms / 1000)`` samples scaled
        by a ramp running from 1.0 down to 0.0. If that length is zero, or is
        not shorter than the signal, ``x`` is returned unchanged.
    """
    has_shape = hasattr(x, "shape")
    n = x.shape[-1] if has_shape else len(x)
    k = int(sr * fade_ms / 1000)
    if k <= 0 or k >= n:
        return x
    w = np.linspace(1.0, 0.0, k)
    if isinstance(x, torch.Tensor):
        # Match dtype/device so the multiply is a plain tensor op (e.g. bf16/CUDA).
        w = torch.from_numpy(w).to(dtype=x.dtype, device=x.device)
    # Arrays/tensors: slice the time (last) axis. Plain sequences are 1-D.
    tail = (Ellipsis, slice(-k, None)) if has_shape else slice(-k, None)
    x[tail] = x[tail] * w
    return x
def ensure_models_downloaded():
    """Download the MeanAudio weight repo into ./weights if any variant's checkpoint is missing.

    The configs are checked in order. At the first variant whose
    ``model_path`` does not exist, that variant is logged and a single
    ``snapshot_download`` of the whole repo is performed.
    """
    first_missing = next(
        (name for name, cfg in all_model_cfg.items() if not cfg.model_path.exists()),
        None,
    )
    if first_missing is None:
        return
    log.info(f'Model {first_missing} not found, downloading...')
    snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights")
def _load_net(model_cfg):
    """Build one MeanAudio network for ``model_cfg``, cast to bfloat16 eval mode on ``device``, and load its weights."""
    net = get_mean_audio(model_cfg.model_name, use_rope=True, text_c_dim=512)
    net = net.to(device, torch.bfloat16).eval()
    net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
    return net


def load_model_cache():
    """Load every variant in ``all_model_cfg`` into MODEL_CACHE and build the shared FeaturesUtils.

    Safe to call repeatedly:
    - Variants already present in ``MODEL_CACHE`` are skipped. Previously, a
      ``return`` inside the loop stopped at the first cached variant, so a
      partial load was never completed.
    - ``FEATURE_UTILS_CACHE['default']`` is built only when it is missing.

    Returns:
        None. Results are stored in the module-level caches.
    """
    last_cfg = None
    for variant, model_cfg in all_model_cfg.items():
        last_cfg = model_cfg
        if variant in MODEL_CACHE:
            continue
        log.info(f"Loading model {variant} for the first time...")
        MODEL_CACHE[variant] = _load_net(model_cfg)

    if 'default' in FEATURE_UTILS_CACHE or last_cfg is None:
        return
    # One FeaturesUtils is shared by every variant. It is configured from the
    # last entry of all_model_cfg.
    # NOTE(review): assumes all variants share the same VAE / vocoder / mode — confirm.
    FEATURE_UTILS_CACHE['default'] = FeaturesUtils(
        tod_vae_ckpt=last_cfg.vae_path,
        enable_conditions=True,
        encoder_name="t5_clap",
        mode=last_cfg.mode,
        bigvgan_vocoder_ckpt=last_cfg.bigvgan_16k_path,
        need_vae_encoder=False,
    ).to(device, torch.bfloat16).eval()
def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one pairwise preference judgement as a JSON line to FEEDBACK_FILE.

    Args:
        prompt: Prompt the two compared clips were generated from.
        audio1_path: Path of the first clip.
        audio2_path: Path of the second clip.
        preference: The user's choice, e.g. "audio1", "audio2", "equal", "both_bad".
        additional_comment: Optional free-text note.

    Returns:
        A confirmation message suitable for display in the UI.
    """
    record = dict(
        timestamp=datetime.now().isoformat(),
        prompt=prompt,
        audio1_path=audio1_path,
        audio2_path=audio2_path,
        preference=preference,
        additional_comment=additional_comment,
    )
    with FEEDBACK_FILE.open("a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, ensure_ascii=False) + "\n")
    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"
# Variants sampled with MeanFlow vs. plain flow matching (Euler). A configured
# variant in neither set is rejected with a clear error. Previously it failed
# with an UnboundLocalError on `use_meanflow`.
_MEANFLOW_VARIANTS = frozenset({'meanaudio_s_ac', 'meanaudio_s_full', 'meanaudio_l_full'})
_FLOW_MATCHING_VARIANTS = frozenset({'fluxaudio_s_full'})


def _safe_prompt_stem(prompt):
    """Return a filename-safe stem for ``prompt``.

    Keeps alphanumerics, spaces and underscores, strips trailing spaces,
    maps spaces to '_', and truncates to 50 characters.
    """
    return (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )


@spaces.GPU(duration=60)
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    duration,
    cfg_strength,
    num_steps,
    variant,
    seed
):
    """Generate audio for ``prompt`` with the chosen variant and save it as FLAC under OUTPUT_DIR.

    Args:
        prompt: Text description of the desired audio.
        duration: Target clip length in seconds. Must be positive.
        cfg_strength: Classifier-free guidance scale. Forced to 0 for MeanFlow variants.
        num_steps: Number of sampling steps. Cast to int, then must be positive.
        variant: Key of ``all_model_cfg`` selecting the cached network.
        seed: Seed for the sampling RNG. Cast to int.

    Returns:
        ``(path_of_first_saved_clip, prompt)`` for the two Gradio outputs.

    Raises:
        ValueError: non-positive duration/steps, an unknown variant, or a
            variant with no sampler mapping.
    """
    # Slider values may arrive as floats; MeanFlow/FlowMatching step counts
    # and torch.Generator.manual_seed need ints.
    num_steps = int(num_steps)
    seed = int(seed)
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")

    net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']
    model_cfg = all_model_cfg[variant]
    seq_cfg = model_cfg.seq_cfg
    # NOTE(review): this mutates the shared config and the cached net, so
    # concurrent requests for the same variant would race here.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    if variant in _MEANFLOW_VARIANTS:
        sampler = MeanFlow(steps=num_steps)
        log.info("Using MeanFlow for generation.")
        generation_func = generate_mf
        sampler_arg_name = "mf"
        # MeanFlow variants are sampled without classifier-free guidance.
        cfg_strength = 0
    elif variant in _FLOW_MATCHING_VARIANTS:
        sampler = FlowMatching(
            min_sigma=0, inference_mode="euler", num_steps=num_steps
        )
        log.info("Using FlowMatching for generation.")
        generation_func = generate_fm
        sampler_arg_name = "fm"
    else:
        raise ValueError(f"No sampler configured for model variant: {variant}")

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    audios = generation_func(
        [prompt] * NUM_SAMPLE,
        negative_text=None,
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )

    save_paths = []
    safe_prompt = _safe_prompt_stem(prompt)
    for i, audio in enumerate(audios):
        audio = audio.float().cpu()
        audio = fade_out(audio, seq_cfg.sampling_rate, fade_ms=100)
        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
        save_path = OUTPUT_DIR / filename
        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")
        save_paths.append(str(save_path))

    if device == "cuda":
        torch.cuda.empty_cache()
    return save_paths[0], prompt
# Gradio input and output components
input_text = gr.Textbox(label="Prompt", lines=2)
variant = gr.Dropdown(
    label="Model Variant",
    choices=list(all_model_cfg),
    value='meanaudio_s_full',
    interactive=True,
)
# NOTE(review): output_audio is not referenced by gr_interface, which builds its own outputs.
output_audio = gr.Audio(label="Generated Audio", type="filepath")
denoising_steps = gr.Slider(
    label="Sampling Steps", minimum=1, maximum=25, step=1, value=1, interactive=True
)
cfg_strength = gr.Slider(
    label="Guidance Scale", minimum=1, maximum=10, step=0.5, value=4.5, interactive=True
)
duration = gr.Slider(
    label="Duration", minimum=1, maximum=30, step=1, value=10, interactive=True
)
seed = gr.Slider(
    label="Seed", minimum=1, maximum=100, step=1, value=42, interactive=True
)
description_text = """
### **MeanAudio** is a novel text-to-audio generator that uses **MeanFlow** to synthesize realistic and faithful audio in few sampling steps. It achieves state-of-the-art performance in single-step audio generation and delivers strong performance in multi-step audio generation.
### [📖 **Arxiv**](https://arxiv.org/abs/2508.06098) | [💻 **GitHub**](https://github.com/xiquan-li/MeanAudio) | [🤗 **Model**](https://huggingface.co/AndreasXi/MeanAudio) | [🚀 **Space**](https://huggingface.co/spaces/chenxie95/MeanAudio) | [🌐 **Project Page**](https://meanaudio.github.io/)
"""

# Settings shared by every example row, in the order of the interface inputs
# that follow the prompt: duration, guidance scale, sampling steps, variant, seed.
_EXAMPLE_SETTINGS = (10, 3, 1, "meanaudio_s_full", 42)
_EXAMPLE_PROMPTS = (
    "Guitar and piano playing a warm music, with a soft and gentle melody, perfect for a romantic evening.",
    "Melodic human whistling harmonizing with natural birdsong",
    "A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion",
    "Quiet speech and then and airplane flying away",
    "A basketball bounces rhythmically on a court, shoes squeak against the floor, and a referee’s whistle cuts through the air",
    "Chopping meat on a wooden table.",
    "A vehicle engine revving then accelerating at a high rate as a metal surface is whipped followed by tires skidding.",
    "Battlefield scene, continuous roar of artillery and gunfire, high fidelity, the sharp crack of bullets, the thundering explosions of bombs, and the screams of wounded soldiers.",
    "Pop music that upbeat, catchy, and easy to listen, high fidelity, with simple melodies, electronic instruments and polished production.",
    "A fast-paced instrumental piece with a classical vibe featuring stringed instruments, evoking an energetic and uplifting mood.",
)

gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, duration, cfg_strength, denoising_steps, variant, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
    description=description_text,
    flagging_mode="never",
    examples=[[example_prompt, *_EXAMPLE_SETTINGS] for example_prompt in _EXAMPLE_PROMPTS],
    # Examples are generated on first use rather than at startup.
    cache_examples="lazy",
)
if __name__ == "__main__":
    # Fetch weights if needed and populate the model caches before serving.
    ensure_models_downloaded()
    load_model_cache()
    # NOTE(review): in Gradio 4+/5 the first positional parameter of
    # Blocks.queue() is `status_update_rate`, not a concurrency count or queue
    # size, so this 15 probably does not do what was intended. Confirm, then
    # pass the intended keyword explicitly (e.g. max_size= or
    # default_concurrency_limit=).
    queued_interface = gr_interface.queue(15)
    queued_interface.launch()