import warnings

import spaces

warnings.filterwarnings("ignore")

import logging
from argparse import ArgumentParser
from pathlib import Path

import torch
import torchaudio
import gradio as gr
from transformers import AutoModel
import laion_clap

from meanaudio.eval_utils import (
    ModelConfig,
    all_model_cfg,
    generate_mf,
    generate_fm,
    setup_eval_logging,
)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import MeanAudio, get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import gc
import json
from datetime import datetime

from huggingface_hub import snapshot_download
import numpy as np

log = logging.getLogger()

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

setup_eval_logging()

OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
NUM_SAMPLE = 1

# Directory holding RLHF preference-feedback data (one JSON object per line).
FEEDBACK_DIR = Path("./rlhf")
FEEDBACK_DIR.mkdir(exist_ok=True)
FEEDBACK_FILE = FEEDBACK_DIR / "user_preferences.jsonl"

# Global model cache to avoid reloading
MODEL_CACHE = {}
FEATURE_UTILS_CACHE = {}

# Sampler used by each supported variant. A variant present in all_model_cfg
# but in neither set is rejected explicitly instead of crashing later.
MEANFLOW_VARIANTS = frozenset({"meanaudio_s_ac", "meanaudio_s_full"})
FLOW_MATCHING_VARIANTS = frozenset({"fluxaudio_s_full"})


def fade_out(x, sr, fade_ms=50):
    """Apply a linear fade-out over the last ``fade_ms`` milliseconds of ``x``.

    The fade runs along the LAST axis, so both a 1-D signal and a
    ``(channels, samples)`` array/tensor are handled. (Using ``len(x)`` on a
    2-D tensor would measure channels, not samples, and skip the fade.)

    Args:
        x: numpy array or torch tensor whose last axis is time; modified in
            place. A plain sequence is treated as a 1-D signal.
        sr: sampling rate in Hz.
        fade_ms: fade length in milliseconds.

    Returns:
        ``x`` itself, faded in place. Returned unchanged when the fade length is
        non-positive or not shorter than the signal.
    """
    n = x.shape[-1] if hasattr(x, "shape") else len(x)
    k = int(sr * fade_ms / 1000)
    if k <= 0 or k >= n:
        return x
    ramp = np.linspace(1.0, 0.0, k)
    if isinstance(x, torch.Tensor):
        # Multiply tensor-by-tensor on the same device/dtype as the audio.
        ramp = torch.from_numpy(ramp).to(dtype=x.dtype, device=x.device)
    if hasattr(x, "shape"):
        x[..., -k:] = x[..., -k:] * ramp
    else:
        x[-k:] = x[-k:] * ramp
    return x


def ensure_models_downloaded():
    """Download the MeanAudio weights into ``./weights`` if any checkpoint is missing.

    A single ``snapshot_download`` of the repository is issued on the first
    missing variant checkpoint, after which the loop stops.
    """
    for variant, model_cfg in all_model_cfg.items():
        if not model_cfg.model_path.exists():
            log.info(f'Model {variant} not found, downloading...')
            snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights")
            break


def load_model_cache():
    """Load every model variant, plus the shared feature utilities, into the global caches.

    Variants already present in ``MODEL_CACHE`` are skipped, so repeated calls
    only load what is missing. The shared ``FeaturesUtils`` is built once under
    the ``'default'`` key from the LAST variant's config, matching what the
    previous implementation ended up caching.
    NOTE(review): this assumes all variants share the same VAE/vocoder/mode
    settings — confirm if variants with different configs are added.

    Returns:
        None.
    """
    last_cfg = None
    for variant, model_cfg in all_model_cfg.items():
        last_cfg = model_cfg
        if variant in MODEL_CACHE:
            continue
        log.info(f"Loading model {variant} for the first time...")
        net = get_mean_audio(model_cfg.model_name, use_rope=True, text_c_dim=512)
        net = net.to(device, torch.bfloat16).eval()
        net.load_weights(torch.load(model_cfg.model_path, map_location=device, weights_only=True))
        MODEL_CACHE[variant] = net

    if last_cfg is not None and 'default' not in FEATURE_UTILS_CACHE:
        FEATURE_UTILS_CACHE['default'] = FeaturesUtils(
            tod_vae_ckpt=last_cfg.vae_path,
            enable_conditions=True,
            encoder_name="t5_clap",
            mode=last_cfg.mode,
            bigvgan_vocoder_ckpt=last_cfg.bigvgan_16k_path,
            need_vae_encoder=False,
        ).to(device, torch.bfloat16).eval()


def save_preference_feedback(prompt, audio1_path, audio2_path, preference, additional_comment=""):
    """Append one pairwise-preference record to ``FEEDBACK_FILE`` (JSON Lines).

    Args:
        prompt: the text prompt both audio clips were generated from.
        audio1_path: path of the first clip.
        audio2_path: path of the second clip.
        preference: one of ``"audio1"``, ``"audio2"``, ``"equal"``, ``"both_bad"``.
        additional_comment: optional free-text comment from the user.

    Returns:
        A confirmation message suitable for display in the UI.
    """
    feedback_data = {
        "timestamp": datetime.now().isoformat(),
        "prompt": prompt,
        "audio1_path": audio1_path,
        "audio2_path": audio2_path,
        "preference": preference,  # "audio1", "audio2", "equal", "both_bad"
        "additional_comment": additional_comment,
    }

    with open(FEEDBACK_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")

    log.info(f"Preference feedback saved: {preference} for prompt: '{prompt[:50]}...'")
    return f"✅ Thanks for your feedback, preference recorded: {preference}"


@spaces.GPU(duration=60)
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    duration,
    cfg_strength,
    num_steps,
    variant,
):
    """Generate one audio clip for ``prompt`` and save it as FLAC under ``OUTPUT_DIR``.

    Args:
        prompt: text description of the sound to generate.
        duration: requested clip length in seconds; must be > 0.
        cfg_strength: classifier-free guidance scale. It is forced to 0 for
            MeanFlow variants.
        num_steps: number of sampling steps; must be > 0. It is coerced to
            ``int`` because UI sliders may deliver floats.
        variant: key into ``all_model_cfg``.

    Returns:
        ``(path_of_first_saved_flac, prompt)``.

    Raises:
        ValueError: for non-positive duration/steps, an unknown variant, or a
            variant with no configured sampler.
    """
    num_steps = int(num_steps)
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
    if variant in MEANFLOW_VARIANTS:
        use_meanflow = True
    elif variant in FLOW_MATCHING_VARIANTS:
        use_meanflow = False
    else:
        # Without this guard use_meanflow would be unbound below.
        raise ValueError(f"No sampler configured for model variant: {variant}")

    # Caches are filled in __main__; load lazily if this module was imported
    # some other way and the caches are still empty.
    if variant not in MODEL_CACHE or 'default' not in FEATURE_UTILS_CACHE:
        ensure_models_downloaded()
        load_model_cache()
    net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']

    model = all_model_cfg[variant]
    seq_cfg = model.seq_cfg
    # NOTE(review): seq_cfg is the shared per-variant config object and net is
    # the cached network, so these presumably-mutating calls are only safe while
    # generation requests are processed one at a time.
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    if use_meanflow:
        sampler = MeanFlow(steps=num_steps)
        log.info("Using MeanFlow for generation.")
        generation_func = generate_mf
        sampler_arg_name = "mf"
        cfg_strength = 0  # MeanFlow variants are sampled without CFG.
    else:
        sampler = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        log.info("Using FlowMatching for generation.")
        generation_func = generate_fm
        sampler_arg_name = "fm"

    # Fixed seed: identical prompt/settings reproduce identical audio.
    rng = torch.Generator(device=device)
    rng.manual_seed(42)

    audios = generation_func(
        [prompt] * NUM_SAMPLE,
        negative_text=None,
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )

    # Keep only alphanumerics, spaces and underscores so the prompt is filename-safe.
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )
    save_paths = []
    for i, audio in enumerate(audios):
        audio = audio.float().cpu()
        # torchaudio.save expects a 2-D (channels, samples) tensor; fade_out
        # works on the last (time) axis.
        audio = fade_out(audio, seq_cfg.sampling_rate)
        current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{safe_prompt}_{current_time_string}_{i}.flac"
        save_path = OUTPUT_DIR / filename
        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")
        save_paths.append(str(save_path))

    if device == "cuda":
        torch.cuda.empty_cache()

    return save_paths[0], prompt


# Gradio input and output components
input_text = gr.Textbox(lines=2, label="Prompt")
output_audio = gr.Audio(label="Generated Audio", type="filepath")
denoising_steps = gr.Slider(minimum=1, maximum=25, value=1, step=1, label="Sampling Steps", interactive=True)
cfg_strength = gr.Slider(minimum=1, maximum=10, value=4.5, step=0.5, label="Guidance Scale", interactive=True)
duration = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration", interactive=True)
# seed = gr.Slider(minimum=1, maximum=1000000, value=42, step=1, label="Seed", interactive=True)
variant = gr.Dropdown(label="Model Variant", choices=list(all_model_cfg.keys()), value='meanaudio_s_full', interactive=True)

description_text = """
**MeanAudio** is a novel text-to-audio generator that uses **MeanFlow** to synthesize realistic and faithful audio in few sampling steps. It achieves state-of-the-art performance in single-step audio generation and delivers strong performance in multi-step audio generation.
"""

gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
    outputs=[
        gr.Audio(label="🎵 Audio Sample", type="filepath"),
        gr.Textbox(label="Prompt Used", interactive=False),
    ],
    title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
    description="",
    flagging_mode="never",
    examples=[
        ["Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!", 10, 3, 1, "meanaudio_s_full"],
        ["Melodic human whistling harmonizing with natural birdsong", 10, 3, 1, "meanaudio_s_full"],
        ["A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion", 10, 3, 1, "meanaudio_s_full"],
        ["Quiet speech and then and airplane flying away", 10, 3, 1, "meanaudio_s_full"],
        ["A soccer ball hits a goalpost with a metallic clang, followed by cheers, clapping, and the distant hum of a commentator’s voice", 10, 3, 1, "meanaudio_s_full"],
        ["A basketball bounces rhythmically on a court, shoes squeak against the floor, and a referee’s whistle cuts through the air", 10, 3, 1, "meanaudio_s_full"],
        ["Dripping water echoes sharply, a distant growl reverberates through the cavern, and soft scraping metal suggests something lurking unseen", 10, 3, 1, "meanaudio_s_full"],
        ["A cow is mooing whilst a lion is roaring in the background as a hunter shoots. A flock of birds subsequently fly away from the trees.", 10, 3, 1, "meanaudio_s_full"],
        ["The deep growl of an alligator ripples through the swamp as reeds sway with a soft rustle and a turtle splashes into the murky water", 10, 3, 1, "meanaudio_s_full"],
        ["Gentle female voice cooing and baby responding with happy gurgles and giggles", 10, 3, 1, "meanaudio_s_full"],
        ['doorbell ding once followed by footsteps gradually getting louder and a door is opened ', 10, 3, 1, "meanaudio_s_full"],
        ["A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background", 10, 3, 1, "meanaudio_s_full"],
    ],
    cache_examples="lazy",
)


if __name__ == "__main__":
    ensure_models_downloaded()
    load_model_cache()
    # In Gradio 4/5 the first positional parameter of queue() is
    # status_update_rate, so queue(15) did not limit anything. Bound the queue
    # length instead and keep the default per-event concurrency. Generation
    # mutates shared model/config state, so it should not run in parallel.
    gr_interface.queue(max_size=15).launch()