"""Gradio demo for MeanAudio text-to-audio generation (MeanFlow / FlowMatching)."""

import warnings

# Kept ahead of the torch imports, as in the original file.
# NOTE(review): HF `spaces` (ZeroGPU) presumably must be imported before CUDA
# is initialized — confirm before reordering.
import spaces

# Installed before the heavy imports below so FutureWarnings raised while
# importing them are suppressed as well.
warnings.filterwarnings("ignore", category=FutureWarning)

import gc
import logging
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace

import torch
import torchaudio
import gradio as gr

from meanaudio.eval_utils import (
    ModelConfig,
    all_model_cfg,
    generate_mf,
    generate_fm,
    setup_eval_logging,
)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import MeanAudio, get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

device = "cuda" if torch.cuda.is_available() else "cpu"

setup_eval_logging()

OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def _empty_model_state():
    """Return a fresh model-cache dict with nothing loaded."""
    return {
        "net": None,
        "feature_utils": None,
        "seq_cfg": None,
        "args": None,
    }


# Cache of the currently loaded model; rebuilt only when the settings change.
current_model_state = _empty_model_state()


def _release_model():
    """Drop the cached model and free memory before a reload.

    Without this, the previous network and feature utilities stay referenced
    from ``current_model_state`` while the new ones are being built, so two
    models are resident at the same time during a reload.
    """
    global current_model_state
    current_model_state = _empty_model_state()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_model_if_needed(
    variant, model_path, encoder_name, use_rope, text_c_dim, full_precision
):
    """Load the requested model into ``current_model_state`` unless already loaded.

    Args:
        variant: Key into ``all_model_cfg`` selecting the model configuration.
        model_path: Path of the weights file passed to ``torch.load``.
        encoder_name: Text encoder name forwarded to ``FeaturesUtils``.
        use_rope: Forwarded to ``get_mean_audio``.
        text_c_dim: Forwarded to ``get_mean_audio``.
        full_precision: ``True`` for float32, ``False`` for bfloat16.

    Returns:
        ``True`` if a (re)load happened, ``False`` if the cached model matched
        the requested settings and was reused.

    Raises:
        ValueError: If ``variant`` is not a key of ``all_model_cfg``.
        Exception: Any error raised while building the network, loading the
            weights, or building ``FeaturesUtils`` is logged and re-raised.
            The cache is left empty in that case.
    """
    global current_model_state

    # Attribute names match what callers read from current_model_state["args"].
    requested = SimpleNamespace(
        variant=variant,
        model_path=model_path,
        encoder_name=encoder_name,
        use_rope=use_rope,
        text_c_dim=text_c_dim,
        full_precision=full_precision,
    )
    # SimpleNamespace equality compares all attributes; None never matches.
    if current_model_state["args"] == requested:
        log.info(f"Model '{variant}' already loaded with current settings.")
        return False

    # Free the previous model first so two models are never resident at once.
    _release_model()

    dtype = torch.float32 if full_precision else torch.bfloat16
    try:
        if variant not in all_model_cfg:
            raise ValueError(f"Unknown model variant: {variant}")
        model: ModelConfig = all_model_cfg[variant]

        net: MeanAudio = (
            get_mean_audio(
                model.model_name,
                use_rope=use_rope,
                text_c_dim=text_c_dim,
            )
            .to(device, dtype)
            .eval()
        )
        net.load_weights(
            torch.load(model_path, map_location=device, weights_only=True)
        )
        log.info(f"Loaded weights from {model_path}")

        feature_utils = FeaturesUtils(
            tod_vae_ckpt=model.vae_path,
            enable_conditions=True,
            encoder_name=encoder_name,
            mode=model.mode,
            bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
            need_vae_encoder=False,
        )
        feature_utils = feature_utils.to(device, dtype).eval()
    except Exception as e:
        log.exception(f"Error loading model: {e}")
        current_model_state = _empty_model_state()
        raise

    current_model_state = {
        "net": net,
        "feature_utils": feature_utils,
        "seq_cfg": model.seq_cfg,
        "args": requested,
    }
    log.info(f"Model '{variant}' loaded successfully.")
    return True


def _output_filename(prompt):
    """Build a filesystem-safe, timestamped ``.flac`` filename from *prompt*.

    Keeps alphanumerics, spaces and underscores, turns spaces into underscores,
    and truncates to 50 characters. Falls back to ``audio`` when nothing
    remains, so the name never starts with a bare ``_``.
    """
    safe_prompt = (
        "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
        .rstrip()
        .replace(" ", "_")[:50]
    )
    current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    return f"{safe_prompt or 'audio'}_{current_time_string}.flac"


@spaces.GPU
def generate_audio_gradio(
    prompt,
    negative_prompt,
    duration,
    cfg_strength,
    num_steps,
    seed,
    variant,
    full_precision,
):
    """Gradio callback: generate one audio clip and save it under ``OUTPUT_DIR``.

    Args:
        prompt: Text description of the sound to generate.
        negative_prompt: Text describing sounds to steer away from.
        duration: Clip length in seconds (written to ``seq_cfg.duration``).
        cfg_strength: Classifier-free guidance strength. It is overridden to 3
            for the ``meanaudio_mf`` variant, as the UI label states.
        num_steps: Number of sampler steps.
        seed: RNG seed. A negative value (or a cleared field) means random.
        variant: Model variant key. ``"meanaudio_mf"`` selects MeanFlow and
            any other value selects FlowMatching.
        full_precision: ``True`` for float32, ``False`` for bfloat16.

    Returns:
        ``(status_message, file_path)``. ``file_path`` is ``None`` on error.
    """
    use_meanflow = variant == "meanaudio_mf"
    model_path = (
        "./weights/meanaudio_mf.pth" if use_meanflow else "./weights/fluxaudio_fm.pth"
    )
    encoder_name = "t5_clap"
    use_rope = True
    text_c_dim = 512

    try:
        load_model_if_needed(
            variant, model_path, encoder_name, use_rope, text_c_dim, full_precision
        )
    except Exception as e:
        return f"Error loading model: {str(e)}", None

    if current_model_state["net"] is None:
        return "Error: Model could not be loaded.", None

    net = current_model_state["net"]
    feature_utils = current_model_state["feature_utils"]
    seq_cfg = current_model_state["seq_cfg"]

    try:
        seq_cfg.duration = duration
        net.update_seq_lengths(seq_cfg.latent_seq_len)

        rng = torch.Generator(device=device)
        if seed is not None and seed >= 0:
            # manual_seed requires an int; gr.Number may hand back a float.
            rng.manual_seed(int(seed))
        else:
            rng.seed()

        if use_meanflow:
            sampler = MeanFlow(steps=int(num_steps))
            log.info("Using MeanFlow for generation.")
            generation_func = generate_mf
            sampler_arg_name = "mf"
            cfg_strength = 3  # MeanFlow guidance is fixed, per the UI label.
        else:
            sampler = FlowMatching(
                min_sigma=0, inference_mode="euler", num_steps=int(num_steps)
            )
            log.info("Using FlowMatching for generation.")
            generation_func = generate_fm
            sampler_arg_name = "fm"

        audios = generation_func(
            [prompt],
            negative_text=[negative_prompt],
            feature_utils=feature_utils,
            net=net,
            rng=rng,
            cfg_strength=cfg_strength,
            **{sampler_arg_name: sampler},
        )
        audio = audios.float().cpu()[0]

        save_path = OUTPUT_DIR / _output_filename(prompt)
        torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
        log.info(f"Audio saved to {save_path}")

        sampler_label = "MeanFlow" if use_meanflow else "FlowMatching"
        return (
            f"Generated audio for prompt: '{prompt}' using {sampler_label}",
            str(save_path),
        )
    except Exception as e:
        log.exception(f"Generation error: {e}")
        return f"Error during generation: {str(e)}", None
    finally:
        gc.collect()


theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size="sm",
    spacing_size="sm",
).set(
    background_fill_primary="*neutral_50",
    background_fill_secondary="*background_fill_primary",
    block_background_fill="*background_fill_primary",
    block_border_width="0px",
    panel_background_fill="*neutral_50",
    panel_border_width="0px",
    input_background_fill="*neutral_100",
    input_border_color="*neutral_200",
    button_primary_background_fill="*primary_300",
    button_primary_background_fill_hover="*primary_400",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_hover="*neutral_300",
)

custom_css = """
#main-header {
    text-align: center;
    margin-top: 5px;
    margin-bottom: 10px;
    color: var(--neutral-600);
    font-weight: 600;
}
#model-settings-header, #generation-settings-header {
    color: var(--neutral-600);
    margin-top: 8px;
    margin-bottom: 8px;
    font-weight: 500;
    font-size: 1.1em;
}
.setting-section {
    padding: 10px 12px;
    border-radius: 6px;
    background-color: var(--neutral-50);
    margin-bottom: 10px;
    border: 1px solid var(--neutral-100);
}
hr {
    border: none;
    height: 1px;
    background-color: var(--neutral-200);
    margin: 8px 0;
}
#generate-btn {
    width: 100%;
    max-width: 250px;
    margin: 10px auto;
    display: block;
    padding: 10px 15px;
    font-size: 16px;
    border-radius: 5px;
}
#status-box {
    min-height: 50px;
    display: flex;
    align-items: center;
    justify-content: center;
    padding: 8px;
    border-radius: 5px;
    border: 1px solid var(--neutral-200);
    color: var(--neutral-700);
}
#audio-output {
    height: 100px;
    border-radius: 5px;
    border: 1px solid var(--neutral-200);
}
.gradio-dropdown label, .gradio-checkbox label, .gradio-number label, .gradio-textbox label {
    font-weight: 500;
    color: var(--neutral-700);
    font-size: 0.9em;
}
.gradio-row {
    gap: 8px;
}
.gradio-block {
    margin-bottom: 8px;
}
.setting-section .gradio-block {
    margin-bottom: 6px;
}
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}
::-webkit-scrollbar-track {
    background: var(--neutral-100);
    border-radius: 4px;
}
::-webkit-scrollbar-thumb {
    background: var(--neutral-300);
    border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
    background: var(--neutral-400);
}
* {
    scrollbar-width: thin;
    scrollbar-color: var(--neutral-300) var(--neutral-100);
}
"""

with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo:
    gr.Markdown("# MeanAudio Text-to-Audio Generator", elem_id="main-header")

    gr.Markdown("### Model and Generation Settings", elem_id="model-settings-header")
    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            available_variants = list(all_model_cfg) if all_model_cfg else []
            default_variant = "meanaudio_mf"
            variant = gr.Dropdown(
                label="Model Variant",
                choices=available_variants,
                value=default_variant,
                interactive=True,
                scale=3,
            )
            full_precision = gr.Checkbox(
                label="Full Precision (float32)", value=True, scale=1
            )

    gr.Markdown("### Audio Generation", elem_id="generation-settings-header")
    with gr.Column(elem_classes="setting-section"):
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the sound you want to generate...",
                scale=1,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Describe sounds you want to avoid...",
                value="",
                scale=1,
            )
        with gr.Row():
            duration = gr.Number(
                label="Duration (sec)", value=10.0, minimum=0.1, scale=1
            )
            cfg_strength = gr.Number(
                label="CFG (Meanflow forced to 3)", value=3, minimum=0.0, scale=1
            )
        with gr.Row():
            seed = gr.Number(
                label="Seed (-1 for random)", value=42, precision=0, scale=1
            )
            num_steps = gr.Number(
                label="Number of Steps",
                value=1,
                precision=0,
                minimum=1,
                scale=1,
            )

    generate_button = gr.Button("Generate", variant="primary", elem_id="generate-btn")
    generate_output_text = gr.Textbox(
        label="Result Status", interactive=False, elem_id="status-box"
    )
    audio_output = gr.Audio(
        label="Generated Audio", type="filepath", elem_id="audio-output"
    )

    generate_button.click(
        fn=generate_audio_gradio,
        inputs=[
            prompt,
            negative_prompt,
            duration,
            cfg_strength,
            num_steps,
            seed,
            variant,
            full_precision,
        ],
        outputs=[generate_output_text, audio_output],
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--port", type=int, default=7861)
    args = parser.parse_args()
    # Gradio documents allowed_paths as a list of str.
    demo.launch(server_port=args.port, allowed_paths=[str(OUTPUT_DIR.resolve())])