import warnings

# NOTE: keep `spaces` ahead of the torch import, as in the original ordering
# (Hugging Face ZeroGPU Spaces expect `spaces` to be imported before CUDA is touched).
import spaces

warnings.filterwarnings("ignore")

import gc
import json
import logging
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path

import torch
import torchaudio
import gradio as gr
from transformers import AutoModel
import laion_clap
import numpy as np
from huggingface_hub import snapshot_download

from meanaudio.eval_utils import (
    ModelConfig,
    all_model_cfg,
    generate_mf,
    generate_fm,
    setup_eval_logging,
)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import MeanAudio, get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils

# Allow TF32 math for CUDA matmuls and cuDNN convolutions (faster inference).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

device = "cuda" if torch.cuda.is_available() else "cpu"

setup_eval_logging()

OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
NUM_SAMPLE = 1

# Create the directory that stores RLHF preference feedback.
FEEDBACK_DIR = Path("./rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"

# Global model caches so each network is loaded only once per process.
MODEL_CACHE = {}
FEATURE_UTILS_CACHE = {}

# Variants sampled with MeanFlow vs. with Euler flow matching.
MEANFLOW_VARIANTS = frozenset({"meanaudio_s_ac", "meanaudio_s_full", "meanaudio_l_full"})
FLOW_MATCHING_VARIANTS = frozenset({"fluxaudio_s_full"})


def fade_out(x, sr, fade_ms=50):
    """Apply a linear fade-out to the last ``fade_ms`` milliseconds of ``x``, in place.

    The fade runs along the LAST axis, so both 1-D ``(samples,)`` signals and
    channel-first ``(channels, samples)`` signals (the layout ``torchaudio.save``
    expects) are handled. Taking the length from the first axis would treat the
    channel count as the signal length and silently skip the fade.

    Args:
        x: ``torch.Tensor`` or ``numpy.ndarray`` holding the audio; modified in place.
        sr: Sampling rate in Hz.
        fade_ms: Fade length in milliseconds.

    Returns:
        ``x`` (the same object), faded. If the fade would be empty or at least as
        long as the signal, ``x`` is returned unchanged.
    """
    n = x.shape[-1]
    k = int(sr * fade_ms / 1000)
    if k <= 0 or k >= n:
        return x
    if isinstance(x, torch.Tensor):
        # Build the ramp with the tensor's own dtype/device so no numpy<->torch mixing occurs.
        ramp_dtype = x.dtype if x.is_floating_point() else torch.float32
        w = torch.linspace(1.0, 0.0, k, dtype=ramp_dtype, device=x.device)
    else:
        w = np.linspace(1.0, 0.0, k)
    x[..., -k:] = x[..., -k:] * w
    return x


def ensure_models_downloaded():
    """Download the MeanAudio weight snapshot if any configured model file is missing.

    The whole ``AndreasXi/MeanAudio`` repo is fetched into ``./weights`` once, on the
    first missing variant.
    """
    # NOTE(review): assumes every model_cfg.model_path lives under ./weights — confirm.
    for variant, model_cfg in all_model_cfg.items():
        if not model_cfg.model_path.exists():
            log.info(f'Model {variant} not found, downloading...')
            snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights")
            break


def load_model_cache():
    """Load every model variant and the shared feature utilities into the global caches.

    Variants already in ``MODEL_CACHE`` are skipped, so repeated calls are cheap.
    The ``FeaturesUtils`` stack (text encoder, VAE decoder, vocoder) is built once and
    stored under ``FEATURE_UTILS_CACHE['default']``. It is configured from the last
    entry of ``all_model_cfg``, the same config that ended up cached when it was
    rebuilt on every loop iteration.
    """
    last_cfg = None
    for variant, model_cfg in all_model_cfg.items():
        last_cfg = model_cfg
        if variant in MODEL_CACHE:
            continue
        log.info(f"Loading model {variant} for the first time...")
        net = get_mean_audio(model_cfg.model_name, use_rope=True, text_c_dim=512)
        net = net.to(device, torch.bfloat16).eval()
        net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
        MODEL_CACHE[variant] = net

    if last_cfg is not None and 'default' not in FEATURE_UTILS_CACHE:
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=last_cfg.vae_path,
            enable_conditions=True,
            encoder_name="t5_clap",
            mode=last_cfg.mode,
            bigvgan_vocoder_ckpt=last_cfg.bigvgan_16k_path,
            need_vae_encoder=False,
        ).to(device, torch.bfloat16).eval()
        FEATURE_UTILS_CACHE['default'] = feature_utils


def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one pairwise-preference record as a JSON line to ``FEEDBACK_FILE``.

    Args:
        prompt: Text prompt used for both audio samples.
        audio1_path: Path of the first sample.
        audio2_path: Path of the second sample.
        preference: One of ``"audio1"``, ``"audio2"``, ``"equal"``, ``"both_bad"``.
        additional_comment: Optional free-text comment.

    Returns:
        A confirmation message for display in the UI.
    """
    feedback_data = {
        "timestamp": datetime.now().isoformat(),
        "prompt": prompt,
        "audio1_path": audio1_path,
        "audio2_path": audio2_path,
        "preference": preference,  # "audio1", "audio2", "equal", "both_bad"
        "additional_comment": additional_comment,
    }
    with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"


@spaces.GPU(duration=60)
@torch.inference_mode()
def generate_audio_gradio(prompt, duration, cfg_strength, num_steps, variant, seed):
    """Generate audio for ``prompt`` with the selected model variant and save it as FLAC.

    Args:
        prompt: Text description of the audio.
        duration: Length of the generated clip (passed to the sequence config).
        cfg_strength: Classifier-free guidance scale (forced to 0 for MeanFlow variants).
        num_steps: Number of sampling steps; coerced to ``int``.
        variant: Key of ``all_model_cfg``.
        seed: RNG seed; coerced to ``int``.

    Returns:
        ``(path_to_first_saved_file, prompt)``.

    Raises:
        ValueError: Non-positive duration/steps, unknown variant, or a variant with
            no sampler mapping.
    """
    # Gradio sliders may deliver floats; samplers and torch.Generator need ints.
    num_steps = int(num_steps)
    seed = int(seed)

    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")

    if variant in MEANFLOW_VARIANTS:
        use_meanflow = True
    elif variant in FLOW_MATCHING_VARIANTS:
        use_meanflow = False
    else:
        raise ValueError(f"No sampler configured for model variant: {variant}")

    # Load lazily if the module was started without running load_model_cache().
    if variant not in MODEL_CACHE or 'default' not in FEATURE_UTILS_CACHE:
        load_model_cache()
    net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']

    model = all_model_cfg[variant]
    seq_cfg = model.seq_cfg
    # NOTE(review): this mutates the shared config and the cached net for every
    # subsequent request; safe only while requests run one at a time.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    if use_meanflow:
        sampler = MeanFlow(steps=num_steps)
        log.info("Using MeanFlow for generation.")
        generation_func = generate_mf
        sampler_arg_name = "mf"
        cfg_strength = 0
    else:
        sampler = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        log.info("Using FlowMatching for generation.")
        generation_func = generate_fm
        sampler_arg_name = "fm"

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    audios = generation_func(
        [prompt] * NUM_SAMPLE,
        negative_text=None,
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )

    save_paths = []
    # Keep only alphanumerics, spaces and underscores so the prompt is filename-safe.
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )
    for i, audio in enumerate(audios):
        audio = audio.float().cpu()
        audio = fade_out(audio, seq_cfg.sampling_rate, fade_ms=100)

        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
        save_path = OUTPUT_DIR / filename
        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")
        save_paths.append(str(save_path))

    if device == "cuda":
        torch.cuda.empty_cache()

    return save_paths[0], prompt


# Gradio input and output components
input_text = gr.Textbox(lines=2, label="Prompt")
variant = gr.Dropdown(
    label="Model Variant",
    choices=list(all_model_cfg.keys()),
    value='meanaudio_s_full',
    interactive=True,
)
output_audio = gr.Audio(label="Generated Audio", type="filepath")
denoising_steps = gr.Slider(minimum=1, maximum=25, value=1, step=1, label="Sampling Steps", interactive=True)
cfg_strength = gr.Slider(minimum=1, maximum=10, value=4.5, step=0.5, label="Guidance Scale", interactive=True)
duration = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration", interactive=True)
seed = gr.Slider(minimum=1, maximum=100, value=42, step=1, label="Seed", interactive=True)

description_text = """
### **MeanAudio** is a novel text-to-audio generator that uses **MeanFlow** to synthesize realistic and faithful audio in few sampling steps. It achieves state-of-the-art performance in single-step audio generation and delivers strong performance in multi-step audio generation.
### [📖 **Arxiv**](https://arxiv.org/abs/2508.06098) | [💻 **GitHub**](https://github.com/xiquan-li/MeanAudio) | [🤗 **Model**](https://huggingface.co/AndreasXi/MeanAudio) | [🚀 **Space**](https://huggingface.co/spaces/chenxie95/MeanAudio) | [🌐 **Project Page**](https://meanaudio.github.io/)
"""

gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, duration, cfg_strength, denoising_steps, variant, seed],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
    description=description_text,
    flagging_mode="never",
    examples=[
        ["Guitar and piano playing a warm music, with a soft and gentle melody, perfect for a romantic evening.", 10, 3, 1, "meanaudio_s_full", 42],
        ["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full", 42],
        ["A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion", 10, 3, 1, "meanaudio_s_full", 42],
        ["Quiet speech and then and airplane flying away", 10, 3, 1, "meanaudio_s_full", 42],
        ["A basketball bounces rhythmically on a court, shoes squeak against the floor, and a referee’s whistle cuts through the air", 10, 3, 1, "meanaudio_s_full", 42],
        ["Chopping meat on a wooden table.", 10, 3, 1, "meanaudio_s_full", 42],
        ["A vehicle engine revving then accelerating at a high rate as a metal surface is whipped followed by tires skidding.", 10, 3, 1, "meanaudio_s_full", 42],
        ["Battlefield scene, continuous roar of artillery and gunfire, high fidelity, the sharp crack of bullets, the thundering explosions of bombs, and the screams of wounded soldiers.", 10, 3, 1, "meanaudio_s_full", 42],
        ["Pop music that upbeat, catchy, and easy to listen, high fidelity, with simple melodies, electronic instruments and polished production.", 10, 3, 1, "meanaudio_s_full", 42],
        ["A fast-paced instrumental piece with a classical vibe featuring stringed instruments, evoking an energetic and uplifting mood.", 10, 3, 1, "meanaudio_s_full", 42],
    ],
    cache_examples="lazy",
)


if __name__ == "__main__":
    ensure_models_downloaded()
    load_model_cache()
    # NOTE(review): in Gradio 4/5 the first positional argument of queue() is
    # status_update_rate (not a concurrency limit or max_size) — confirm intent.
    gr_interface.queue(15).launch()