import warnings

import spaces

warnings.filterwarnings("ignore", category=FutureWarning)

import gc
import logging
from datetime import datetime
from pathlib import Path

import gradio as gr
import torch
import torchaudio
from huggingface_hub import snapshot_download

from meanaudio.eval_utils import (
    all_model_cfg,
    generate_fm,
    generate_mf,
    setup_eval_logging,
)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

device = "cuda" if torch.cuda.is_available() else "cpu"

setup_eval_logging()

OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
NUM_SAMPLE = 1

# Pre-download the MeanAudio checkpoints from the Hugging Face Hub.
# snapshot_download(repo_id="google/flan-t5-large")
# a = AutoModel.from_pretrained('bert-base-uncased')  # needs `from transformers import AutoModel`
# b = AutoModel.from_pretrained('roberta-base')
snapshot_download(
    repo_id="AndreasXi/MeanAudio",
    local_dir="./weights",
    allow_patterns=["*.pt", "*.pth"],
)

# Optional CLAP model for reranking (disabled; would also need `import laion_clap`):
# _clap_ckpt_path = './weights/music_speech_audioset_epoch_15_esc_89.98.pt'
# laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').cuda().eval()
# laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)


@spaces.GPU(duration=10)
@torch.inference_mode()
def generate_audio_gradio(
    prompt,
    negative_prompt,
    duration,
    cfg_strength,
    num_steps,
    seed,
    variant,
):
    """Generate audio for a text prompt; return (status message, saved file path)."""
    dtype = torch.float32

    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(
            f"Unknown model variant: {variant}. "
            f"Available: {list(all_model_cfg.keys())}"
        )

    # By default this resolves to meanaudio_s_full.pth or fluxaudio_s_full.pth.
    model = all_model_cfg[variant]
    model_path = model.model_path
    seq_cfg = model.seq_cfg
    seq_cfg.duration = duration

    net = get_mean_audio(model.model_name, use_rope=True, text_c_dim=512)
    net = net.to(device, dtype).eval()
    net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        enable_conditions=True,
        encoder_name="t5_clap",
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False,
    )
    feature_utils = feature_utils.to(device, dtype).eval()

    # MeanFlow variants sample in one (or a few) steps; FlowMatching variants use a
    # standard Euler ODE solver. Deciding by name prefix also covers variant keys
    # such as 'meanaudio_mf' that the previous exact-match checks left unhandled
    # (which raised UnboundLocalError for the UI's default variant).
    use_meanflow = variant.startswith("meanaudio")

    if use_meanflow:
        sampler = MeanFlow(steps=num_steps)
        log.info("Using MeanFlow for generation.")
        generation_func = generate_mf
        sampler_arg_name = "mf"
        cfg_strength = 0  # classifier-free guidance is not applied on the MeanFlow path
    else:
        sampler = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
        log.info("Using FlowMatching for generation.")
        generation_func = generate_fm
        sampler_arg_name = "fm"

    # A negative seed means "random": draw a fresh seed rather than seeding with -1.
    if seed < 0:
        seed = torch.seed() % (2**31)
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    audios = generation_func(
        [prompt] * NUM_SAMPLE,
        negative_text=[negative_prompt] * NUM_SAMPLE,
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )
    audio = audios[0].float().cpu()

    # Optional CLAP reranking of multiple samples (disabled):
    # text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
    # audio_embed = laion_clap_model.get_audio_embedding_from_data(
    #     audios[:, 0, :].float().cpu(), use_tensor=True).squeeze()
    # scores = torch.cosine_similarity(
    #     text_embed.expand(audio_embed.shape[0], -1), audio_embed, dim=-1)
    # log.info(scores)
    # log.info(torch.argmax(scores).item())
    # audio = audios[torch.argmax(scores).item()].float().cpu()

    # Build a filesystem-safe filename from the prompt plus a timestamp.
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )
    current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{safe_prompt}_{current_time_string}.flac"
    save_path = OUTPUT_DIR / filename
    torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
    log.info(f"Audio saved to {save_path}")
    gc.collect()

    return (
        f"Generated audio for prompt: '{prompt}' using "
        f"{'MeanFlow' if use_meanflow else 'FlowMatching'}",
        str(save_path),
    )
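# A minimal sketch of invoking the generator outside the UI, assuming the same
# weights have been downloaded and "meanaudio_mf" is a key of all_model_cfg
# (the prompt string here is illustrative):
#
#     status, audio_path = generate_audio_gradio(
#         prompt="Rain falling on a tin roof",
#         negative_prompt="",
#         duration=10.0,
#         cfg_strength=3.0,
#         num_steps=1,
#         seed=42,
#         variant="meanaudio_mf",
#     )
#     print(status, audio_path)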
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size="sm",
    spacing_size="sm",
).set(
    background_fill_primary="*neutral_50",
    background_fill_secondary="*background_fill_primary",
    block_background_fill="*background_fill_primary",
    block_border_width="0px",
    panel_background_fill="*neutral_50",
    panel_border_width="0px",
    input_background_fill="*neutral_100",
    input_border_color="*neutral_200",
    button_primary_background_fill="*primary_300",
    button_primary_background_fill_hover="*primary_400",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_hover="*neutral_300",
)

custom_css = """
#main-headertitle {
    text-align: center;
    margin-top: 15px;
    margin-bottom: 10px;
    color: var(--neutral-600);
    font-weight: 600;
}
#main-header {
    text-align: center;
    margin-top: 5px;
    margin-bottom: 10px;
    color: var(--neutral-600);
    font-weight: 600;
}
#model-settings-header, #generation-settings-header {
    color: var(--neutral-600);
    margin-top: 8px;
    margin-bottom: 8px;
    font-weight: 500;
    font-size: 1.1em;
}
.setting-section {
    padding: 10px 12px;
    border-radius: 6px;
    background-color: var(--neutral-50);
    margin-bottom: 10px;
    border: 1px solid var(--neutral-100);
}
hr {
    border: none;
    height: 1px;
    background-color: var(--neutral-200);
    margin: 8px 0;
}
#generate-btn {
    width: 100%;
    max-width: 250px;
    margin: 10px auto;
    display: block;
    padding: 10px 15px;
    font-size: 16px;
    border-radius: 5px;
}
#status-box {
    min-height: 50px;
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 8px;
    border-radius: 5px;
    border: 1px solid var(--neutral-200);
    color: var(--neutral-700);
}
#project-badges {
    text-align: center;
    margin-top: 30px;
    margin-bottom: 20px;
}
#project-badges #badge-container {
    display: flex;
    gap: 10px;
    align-items: center;
    justify-content: center;
    flex-wrap: wrap;
}
#project-badges img {
    border-radius: 5px;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
    height: 20px;
    transition: transform 0.1s ease, box-shadow 0.1s ease;
}
#project-badges a:hover img {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
}
#audio-output {
    height: 200px;
    border-radius: 5px;
    border: 1px solid var(--neutral-200);
}
.gradio-dropdown label,
.gradio-checkbox label,
.gradio-number label,
.gradio-textbox label {
    font-weight: 500;
    color: var(--neutral-700);
    font-size: 0.9em;
}
.gradio-row { gap: 8px; }
.gradio-block { margin-bottom: 8px; }
.setting-section .gradio-block { margin-bottom: 6px; }
::-webkit-scrollbar { width: 8px; height: 8px; }
::-webkit-scrollbar-track { background: var(--neutral-100); border-radius: 4px; }
::-webkit-scrollbar-thumb { background: var(--neutral-300); border-radius: 4px; }
::-webkit-scrollbar-thumb:hover { background: var(--neutral-400); }
* {
    scrollbar-width: thin;
    scrollbar-color: var(--neutral-300) var(--neutral-100);
}
"""

with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo:
    gr.Markdown(
        "# MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
        elem_id="main-header",
    )
    badge_html = """
    <div id="project-badges">
      <div id="badge-container">
        <!-- Badge links: Hugging Face Model, Hugging Face Space, Project Page, GitHub -->
      </div>
    </div>
    """
    gr.HTML(badge_html)

    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            available_variants = list(all_model_cfg.keys()) if all_model_cfg else []
            default_variant = "meanaudio_mf"
            variant = gr.Dropdown(
                label="Model Variant",
                choices=available_variants,
                value=default_variant,
                interactive=True,
                scale=3,
            )

    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the sound you want to generate...",
                scale=1,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Describe sounds you want to avoid...",
                value="",
                scale=1,
            )
        with gr.Row():
            duration = gr.Number(
                label="Duration (sec)", value=10.0, minimum=0.1, scale=1
            )
            cfg_strength = gr.Number(
                label="CFG Strength (ignored by MeanFlow variants)",
                value=3,
                minimum=0.0,
                scale=1,
            )
        with gr.Row():
            seed = gr.Number(
                label="Seed (-1 for random)", value=42, precision=0, scale=1
            )
            num_steps = gr.Number(
                label="Number of Steps",
                value=1,
                precision=0,
                minimum=1,
                scale=1,
            )

    generate_button = gr.Button("Generate", variant="primary", elem_id="generate-btn")
    generate_output_text = gr.Textbox(
        label="Result Status", interactive=False, elem_id="status-box"
    )
    audio_output = gr.Audio(
        label="Generated Audio", type="filepath", elem_id="audio-output"
    )

    generate_button.click(
        fn=generate_audio_gradio,
        inputs=[
            prompt,
            negative_prompt,
            duration,
            cfg_strength,
            num_steps,
            seed,
            variant,
        ],
        outputs=[generate_output_text, audio_output],
    )

    # Columns: [prompt, negative_prompt, duration, cfg_strength, num_steps, seed, variant]
    audio_examples = [
        ["Typing on a keyboard", "", 10.0, 3, 1, 42, "meanaudio_mf"],
        ["A man speaks followed by a popping noise and laughter", "", 10.0, 3, 1, 42, "meanaudio_mf"],
        ["Some humming followed by a toilet flushing", "", 10.0, 3, 2, 42, "meanaudio_mf"],
        ["Rain falling on a hard surface as thunder roars in the distance", "", 10.0, 3, 5, 42, "meanaudio_mf"],
        ["Food sizzling and oil popping", "", 10.0, 3, 25, 42, "meanaudio_mf"],
        ["Pots and dishes clanking as a man talks followed by liquid pouring into a container", "", 8.0, 3, 2, 42, "meanaudio_mf"],
        ["A few seconds of silence then a rasping sound against wood", "", 12.0, 3, 2, 42, "meanaudio_mf"],
        ["A man speaks as he gives a speech and then the crowd cheers", "", 10.0, 3, 25, 42, "fluxaudio_fm"],
        ["A goat bleating repeatedly", "", 10.0, 3, 50, 123, "fluxaudio_fm"],
        ["A speech and gunfire followed by a gun being loaded", "", 10.0, 3, 1, 42, "meanaudio_mf"],
        ["Tires squealing followed by an engine revving", "", 12.0, 4, 25, 456, "fluxaudio_fm"],
        ["Hammer slowly hitting the wooden table", "", 10.0, 3.5, 25, 42, "fluxaudio_fm"],
        ["Dog barking excitedly and man shouting as race car engine roars past", "", 10.0, 3, 1, 42, "meanaudio_mf"],
        ["A dog barking and a cat mewing and a racing car passes by", "", 12.0, 3, 5, -1, "meanaudio_mf"],
        ["Whistling with birds chirping", "", 10.0, 4, 50, 42, "fluxaudio_fm"],
    ]

    gr.Examples(
        examples=audio_examples,
        inputs=[prompt, negative_prompt, duration, cfg_strength, num_steps, seed, variant],
        # outputs=[generate_output_text, audio_output],
        # fn=generate_audio_gradio,
        examples_per_page=5,
        label="Example Prompts",
    )

if __name__ == "__main__":
    demo.launch()
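    # Note: launch() also accepts options such as share=True (temporary public
    # URL) and server_port=<int> if the default binding does not suit the
    # deployment; the bare call above keeps Gradio's defaults.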