import warnings | |
import spaces | |
warnings.filterwarnings("ignore") | |
import logging | |
from argparse import ArgumentParser | |
from pathlib import Path | |
import torch | |
import torchaudio | |
import gradio as gr | |
from transformers import AutoModel | |
import laion_clap | |
from meanaudio.eval_utils import ( | |
ModelConfig, | |
all_model_cfg, | |
generate_mf, | |
generate_fm, | |
setup_eval_logging, | |
) | |
from meanaudio.model.flow_matching import FlowMatching | |
from meanaudio.model.mean_flow import MeanFlow | |
from meanaudio.model.networks import MeanAudio, get_mean_audio | |
from meanaudio.model.utils.features_utils import FeaturesUtils | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
import gc | |
from datetime import datetime | |
from huggingface_hub import snapshot_download | |
import numpy as np | |
# Root logger; the handlers/format are installed by setup_eval_logging() below.
log = logging.getLogger()

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

setup_eval_logging()

# Generated audio files are written here; the directory is created on import.
OUTPUT_DIR = Path("./output/gradio")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Number of copies of the prompt passed to the generation function per request.
NUM_SAMPLE = 1

# Global model cache to avoid reloading:
#   MODEL_CACHE maps a variant name to its loaded network,
#   FEATURE_UTILS_CACHE holds the shared FeaturesUtils under the key 'default'.
MODEL_CACHE = {}
FEATURE_UTILS_CACHE = {}
def ensure_models_downloaded():
    """Download the weight snapshot if any configured checkpoint is missing.

    Walks ``all_model_cfg`` in order and stops at the first variant whose
    ``model_path`` does not exist. When such a variant is found, the whole
    "AndreasXi/MeanAudio" repository is fetched into ``./weights`` exactly
    once; when every checkpoint is present nothing is downloaded.
    """
    first_missing = next(
        (
            name
            for name, cfg in all_model_cfg.items()
            if not cfg.model_path.exists()
        ),
        None,
    )
    if first_missing is None:
        return

    log.info(f'Model {first_missing} not found, downloading...')
    snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights")
def load_model_cache():
    """Populate ``MODEL_CACHE`` and ``FEATURE_UTILS_CACHE`` for every variant.

    Each variant in ``all_model_cfg`` that is not cached yet is built with
    ``get_mean_audio``, moved to ``device`` in bfloat16, switched to eval mode
    and has its checkpoint loaded. Variants that are already cached are
    skipped instead of ending the function, so a partially filled cache
    (e.g. after an earlier load raised half-way through) is completed on the
    next call rather than leaving later variants missing.

    The shared ``FeaturesUtils`` instance is stored under the key
    ``'default'``. It is built once, from the configuration of the last
    variant in ``all_model_cfg``, and is not rebuilt on later calls.
    NOTE(review): this assumes the VAE / vocoder paths and ``mode`` are
    interchangeable across variants, as a single shared instance implies.

    Returns:
        None. Results are published through the two module-level caches.
    """
    model_cfg = None
    for variant, model_cfg in all_model_cfg.items():
        if variant in MODEL_CACHE:
            # Already loaded; keep going so the remaining variants get loaded too.
            continue

        log.info(f"Loading model {variant} for the first time...")
        net = get_mean_audio(model_cfg.model_name, use_rope=True, text_c_dim=512)
        net = net.to(device, torch.bfloat16).eval()
        net.load_weights(
            torch.load(model_cfg.model_path, map_location=device, weights_only=True)
        )
        MODEL_CACHE[variant] = net

    # model_cfg is None only when all_model_cfg is empty; nothing to build then.
    if model_cfg is not None and 'default' not in FEATURE_UTILS_CACHE:
        FEATURE_UTILS_CACHE['default'] = FeaturesUtils(
            tod_vae_ckpt=model_cfg.vae_path,
            enable_conditions=True,
            encoder_name="t5_clap",
            mode=model_cfg.mode,
            bigvgan_vocoder_ckpt=model_cfg.bigvgan_16k_path,
            need_vae_encoder=False,
        )
# Variants sampled with MeanFlow and with FlowMatching respectively. A variant
# present in all_model_cfg but in neither tuple is rejected with a clear error.
_MEANFLOW_VARIANTS = ("meanaudio_s_ac", "meanaudio_s_full")
_FLOWMATCHING_VARIANTS = ("fluxaudio_s_full",)


def _fade_out(x, sr, fade_ms=50):
    """Return *x* with a linear fade to zero over its last ``fade_ms`` ms.

    The ramp is applied along the last dimension, so it works for both a 1-D
    waveform and a (channels, samples) tensor. The input is not modified in
    place. Waveforms shorter than the fade window are returned unchanged.
    """
    num_samples = x.shape[-1]
    fade_len = int(sr * fade_ms / 1000)
    if fade_len <= 0 or fade_len >= num_samples:
        return x
    ramp = torch.linspace(1.0, 0.0, fade_len, dtype=x.dtype, device=x.device)
    return torch.cat([x[..., :-fade_len], x[..., -fade_len:] * ramp], dim=-1)


def _prompt_to_file_stem(prompt, max_len=50):
    """Reduce *prompt* to alphanumerics/underscores usable in a file name."""
    kept = "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
    return kept.rstrip().replace(" ", "_")[:max_len]


def _build_sampler(use_meanflow, num_steps):
    """Return ``(sampler, generation_func, sampler_arg_name)`` for the mode."""
    if use_meanflow:
        log.info("Using MeanFlow for generation.")
        return MeanFlow(steps=num_steps), generate_mf, "mf"
    log.info("Using FlowMatching for generation.")
    sampler = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
    return sampler, generate_fm, "fm"


def generate_audio_gradio(
    prompt,
    duration=10,
    cfg_strength=4.5,
    num_steps=1,
    variant="meanaudio_s_full",
):
    """Generate one audio clip for *prompt* and save it as a FLAC file.

    The defaults mirror the initial values of the UI controls so the function
    can still be called when only the prompt is supplied (the Interface
    examples provide a prompt and nothing else).

    Args:
        prompt: Text description of the sound to generate.
        duration: Clip length; written to the variant's ``seq_cfg.duration``.
        cfg_strength: Guidance scale passed to the FlowMatching generator.
            It is overridden to 0 for MeanFlow variants.
        num_steps: Number of sampling steps; converted to ``int``.
        variant: Key into ``all_model_cfg`` / ``MODEL_CACHE``.

    Returns:
        A ``(status_message, file_path)`` tuple, where ``file_path`` is the
        saved FLAC file under ``OUTPUT_DIR``.

    Raises:
        ValueError: If ``duration`` or ``num_steps`` is not positive, if
            ``variant`` is unknown, or if no sampler is configured for it.
        KeyError: If the variant has not been loaded into ``MODEL_CACHE``.
    """
    num_steps = int(num_steps)
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")

    # Decide the sampler family before touching any shared state, so an
    # unsupported variant fails with a clear message instead of an unbound
    # local further down.
    if variant in _MEANFLOW_VARIANTS:
        use_meanflow = True
    elif variant in _FLOWMATCHING_VARIANTS:
        use_meanflow = False
    else:
        raise ValueError(f"No sampler is configured for model variant: {variant}")

    net, feature_utils = MODEL_CACHE[variant], FEATURE_UTILS_CACHE['default']

    # NOTE: seq_cfg is the shared per-variant config object; the requested
    # duration is written into it and the network is resized to match.
    seq_cfg = all_model_cfg[variant].seq_cfg
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    sampler, generation_func, sampler_arg_name = _build_sampler(use_meanflow, num_steps)
    if use_meanflow:
        # The user-supplied guidance scale is ignored for MeanFlow variants.
        cfg_strength = 0

    rng = torch.Generator(device=device)
    # force to 42
    rng.manual_seed(42)

    audios = generation_func(
        [prompt] * NUM_SAMPLE,
        negative_text=None,
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )

    # The tensor saved below must be 2-D for torchaudio.save, so the fade is
    # applied along the last (sample) axis rather than the first.
    audio = _fade_out(audios[0].float().cpu(), seq_cfg.sampling_rate)

    current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"{_prompt_to_file_stem(prompt)}_{current_time_string}.flac"
    save_path = OUTPUT_DIR / filename
    torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
    log.info(f"Audio saved to {save_path}")

    if device == "cuda":
        torch.cuda.empty_cache()

    return (
        f"Generated audio for prompt: '{prompt}' using {'MeanFlow' if use_meanflow else 'FlowMatching'}",
        str(save_path),
    )
# Prompts offered as one-click examples in the UI. Each example fills in the
# prompt only; the remaining controls keep their current values.
EXAMPLE_PROMPTS = (
    "Generate the festive sounds of a fireworks show: explosions lighting up the sky, crowd cheering, and the faint music playing in the background!! Celebration of the new year!",
    "Melodic human whistling harmonizing with natural birdsong",
    "A parade marches through a town square, with drumbeats pounding, children clapping, and a horse neighing amidst the commotion",
    "Quiet speech and then and airplane flying away",
    "A soccer ball hits a goalpost with a metallic clang, followed by cheers, clapping, and the distant hum of a commentator’s voice",
    "A basketball bounces rhythmically on a court, shoes squeak against the floor, and a referee’s whistle cuts through the air",
    "Dripping water echoes sharply, a distant growl reverberates through the cavern, and soft scraping metal suggests something lurking unseen",
    "A cow is mooing whilst a lion is roaring in the background as a hunter shoots. A flock of birds subsequently fly away from the trees.",
    "The deep growl of an alligator ripples through the swamp as reeds sway with a soft rustle and a turtle splashes into the murky water",
    "Gentle female voice cooing and baby responding with happy gurgles and giggles",
    'doorbell ding once followed by footsteps gradually getting louder and a door is opened ',
    "A fork scrapes a plate, water drips slowly into a sink, and the faint hum of a refrigerator lingers in the background",
)

# Gradio input and output components (created in the same order as before).
input_text = gr.Textbox(
    lines=2,
    label="Prompt",
)
# Declared for reference; the Interface below requests its outputs by the
# shortcut names "text" and "audio" instead of using this component.
output_audio = gr.Audio(
    label="Generated Audio",
    type="filepath",
)
denoising_steps = gr.Slider(
    minimum=1,
    maximum=25,
    value=1,
    step=1,
    label="SamplingSteps",
    interactive=True,
)
cfg_strength = gr.Slider(
    minimum=1,
    maximum=10,
    value=4.5,
    step=0.5,
    label="Guidance Scale (For MeanAudio, it is forced to 3 as integrated in training)",
    interactive=True,
)
duration = gr.Slider(
    minimum=1,
    maximum=30,
    value=10,
    step=1,
    label="Duration",
    interactive=True,
)
# seed = gr.Slider(minimum=1, maximum=1000000, value=42, step=1, label="Seed", interactive=True)
variant = gr.Dropdown(
    label="Model Variant",
    choices=list(all_model_cfg.keys()),
    value='meanaudio_s_full',
    interactive=True,
)

# The input order here must match the positional parameters of
# generate_audio_gradio: prompt, duration, cfg_strength, num_steps, variant.
gr_interface = gr.Interface(
    fn=generate_audio_gradio,
    inputs=[input_text, duration, cfg_strength, denoising_steps, variant],
    outputs=["text", "audio"],
    title="MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows",
    description="",
    flagging_mode="never",
    examples=[[example_prompt] for example_prompt in EXAMPLE_PROMPTS],
    cache_examples="lazy",  # Turn on to cache.
)
def main():
    """Fetch missing weights, load all models, then serve the Gradio UI."""
    ensure_models_downloaded()
    load_model_cache()
    # NOTE(review): 15 is passed positionally; confirm which queue() parameter
    # it binds to in the installed Gradio version.
    queued_app = gr_interface.queue(15)
    queued_app.launch()


if __name__ == "__main__":
    main()
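# NOTE: Everything below is a disabled, alternative gr.Blocks-based UI kept for
# reference only. It is out of sync with generate_audio_gradio above: it wires
# up negative_prompt and seed inputs that the function does not accept.
# theme = gr.themes.Soft(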
# primary_hue="blue", | |
# secondary_hue="slate", | |
# neutral_hue="slate", | |
# text_size="sm", | |
# spacing_size="sm", | |
# ).set( | |
# background_fill_primary="*neutral_50", | |
# background_fill_secondary="*background_fill_primary", | |
# block_background_fill="*background_fill_primary", | |
# block_border_width="0px", | |
# panel_background_fill="*neutral_50", | |
# panel_border_width="0px", | |
# input_background_fill="*neutral_100", | |
# input_border_color="*neutral_200", | |
# button_primary_background_fill="*primary_300", | |
# button_primary_background_fill_hover="*primary_400", | |
# button_secondary_background_fill="*neutral_200", | |
# button_secondary_background_fill_hover="*neutral_300", | |
# ) | |
# custom_css = """ | |
# #main-headertitle { | |
# text-align: center; | |
# margin-top: 15px; | |
# margin-bottom: 10px; | |
# color: var(--neutral-600); | |
# font-weight: 600; | |
# } | |
# #main-header { | |
# text-align: center; | |
# margin-top: 5px; | |
# margin-bottom: 10px; | |
# color: var(--neutral-600); | |
# font-weight: 600; | |
# } | |
# #model-settings-header, #generation-settings-header { | |
# color: var(--neutral-600); | |
# margin-top: 8px; | |
# margin-bottom: 8px; | |
# font-weight: 500; | |
# font-size: 1.1em; | |
# } | |
# .setting-section { | |
# padding: 10px 12px; | |
# border-radius: 6px; | |
# background-color: var(--neutral-50); | |
# margin-bottom: 10px; | |
# border: 1px solid var(--neutral-100); | |
# } | |
# hr { | |
# border: none; | |
# height: 1px; | |
# background-color: var(--neutral-200); | |
# margin: 8px 0; | |
# } | |
# #generate-btn { | |
# width: 100%; | |
# max-width: 250px; | |
# margin: 10px auto; | |
# display: block; | |
# padding: 10px 15px; | |
# font-size: 16px; | |
# border-radius: 5px; | |
# } | |
# #status-box { | |
# min-height: 50px; | |
# display: flex; | |
# align-items: center; | |
# justify-content: center; | |
# padding: 8px; | |
# border-radius: 5px; | |
# border: 1px solid var(--neutral-200); | |
# color: var(--neutral-700); | |
# } | |
# #project-badges { | |
# text-align: center; | |
# margin-top: 30px; | |
# margin-bottom: 20px; | |
# } | |
# #project-badges #badge-container { | |
# display: flex; | |
# gap: 10px; | |
# align-items: center; | |
# justify-content: center; | |
# flex-wrap: wrap; | |
# } | |
# #project-badges img { | |
# border-radius: 5px; | |
# box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); | |
# height: 20px; | |
# transition: transform 0.1s ease, box-shadow 0.1s ease; | |
# } | |
# #project-badges a:hover img { | |
# transform: translateY(-2px); | |
# box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); | |
# } | |
# #audio-output { | |
# height: 200px; | |
# border-radius: 5px; | |
# border: 1px solid var(--neutral-200); | |
# } | |
# .gradio-dropdown label, .gradio-checkbox label, .gradio-number label, .gradio-textbox label { | |
# font-weight: 500; | |
# color: var(--neutral-700); | |
# font-size: 0.9em; | |
# } | |
# .gradio-row { | |
# gap: 8px; | |
# } | |
# .gradio-block { | |
# margin-bottom: 8px; | |
# } | |
# .setting-section .gradio-block { | |
# margin-bottom: 6px; | |
# } | |
# ::-webkit-scrollbar { | |
# width: 8px; | |
# height: 8px; | |
# } | |
# ::-webkit-scrollbar-track { | |
# background: var(--neutral-100); | |
# border-radius: 4px; | |
# } | |
# ::-webkit-scrollbar-thumb { | |
# background: var(--neutral-300); | |
# border-radius: 4px; | |
# } | |
# ::-webkit-scrollbar-thumb:hover { | |
# background: var(--neutral-400); | |
# } | |
# * { | |
# scrollbar-width: thin; | |
# scrollbar-color: var(--neutral-300) var(--neutral-100); | |
# } | |
# """ | |
# with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo: | |
# gr.Markdown("# MeanAudio: Fast and Faithful Text-to-Audio Generation with Mean Flows", elem_id="main-header") | |
# badge_html = ''' | |
# <div id="project-badges"> <!-- use an ID so that the CSS can be applied -->
# <div id="badge-container"> <!-- add this container div and give it an ID -->
# <a href="https://huggingface.co/junxiliu/MeanAudio"> | |
# <img src="https://img.shields.io/badge/Model-HuggingFace-violet?logo=huggingface" alt="Hugging Face Model"> | |
# </a> | |
# <a href="https://huggingface.co/spaces/chenxie95/MeanAudio"> | |
# <img src="https://img.shields.io/badge/Space-HuggingFace-8A2BE2?logo=huggingface" alt="Hugging Face Space"> | |
# </a> | |
# <a href="https://meanaudio.github.io/"> | |
# <img src="https://img.shields.io/badge/Project-Page-brightred?style=flat" alt="Project Page"> | |
# </a> | |
# <a href="https://github.com/xiquan-li/MeanAudio"> | |
# <img src="https://img.shields.io/badge/Code-GitHub-black?logo=github" alt="GitHub"> | |
# </a> | |
# </div> | |
# </div> | |
# ''' | |
# gr.HTML(badge_html) | |
# with gr.Column(elem_classes="setting-section"): | |
# with gr.Row(): | |
# available_variants = ( | |
# list(all_model_cfg.keys()) if all_model_cfg else [] | |
# ) | |
# default_variant = ( | |
# 'meanaudio_mf' | |
# ) | |
# variant = gr.Dropdown( | |
# label="Model Variant", | |
# choices=available_variants, | |
# value=default_variant, | |
# interactive=True, | |
# scale=3, | |
# ) | |
# with gr.Column(elem_classes="setting-section"): | |
# with gr.Row(): | |
# prompt = gr.Textbox( | |
# label="Prompt", | |
# placeholder="Describe the sound you want to generate...", | |
# scale=1, | |
# ) | |
# negative_prompt = gr.Textbox( | |
# label="Negative Prompt", | |
# placeholder="Describe sounds you want to avoid...", | |
# value="", | |
# scale=1, | |
# ) | |
# with gr.Row(): | |
# duration = gr.Number( | |
# label="Duration (sec)", value=10.0, minimum=0.1, scale=1 | |
# ) | |
# cfg_strength = gr.Number( | |
# label="CFG (Meanflow forced to 3)", value=3, minimum=0.0, scale=1 | |
# ) | |
# with gr.Row(): | |
# seed = gr.Number( | |
# label="Seed (-1 for random)", value=42, precision=0, scale=1 | |
# ) | |
# num_steps = gr.Number( | |
# label="Number of Steps", | |
# value=1, | |
# precision=0, | |
# minimum=1, | |
# scale=1, | |
# ) | |
# generate_button = gr.Button("Generate", variant="primary", elem_id="generate-btn") | |
# generate_output_text = gr.Textbox( | |
# label="Result Status", interactive=False, elem_id="status-box" | |
# ) | |
# audio_output = gr.Audio( | |
# label="Generated Audio", type="filepath", elem_id="audio-output" | |
# ) | |
# generate_button.click( | |
# fn=generate_audio_gradio, | |
# inputs=[ | |
# prompt, | |
# negative_prompt, | |
# duration, | |
# cfg_strength, | |
# num_steps, | |
# seed, | |
# variant, | |
# ], | |
# outputs=[generate_output_text, audio_output], | |
# ) | |
# audio_examples = [ | |
# ["Typing on a keyboard", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
# ["A man speaks followed by a popping noise and laughter", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
# ["Some humming followed by a toilet flushing", "", 10.0, 3, 2, 42, "meanaudio_mf"], | |
# ["Rain falling on a hard surface as thunder roars in the distance", "", 10.0, 3, 5, 42, "meanaudio_mf"], | |
# ["Food sizzling and oil popping", "", 10.0, 3, 25, 42, "meanaudio_mf"], | |
# ["Pots and dishes clanking as a man talks followed by liquid pouring into a container", "", 8.0, 3, 2, 42, "meanaudio_mf"], | |
# ["A few seconds of silence then a rasping sound against wood", "", 12.0, 3, 2, 42, "meanaudio_mf"], | |
# ["A man speaks as he gives a speech and then the crowd cheers", "", 10.0, 3, 25, 42, "fluxaudio_fm"], | |
# ["A goat bleating repeatedly", "", 10.0, 3, 50, 123, "fluxaudio_fm"], | |
# ["A speech and gunfire followed by a gun being loaded", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
# ["Tires squealing followed by an engine revving", "", 12.0, 4, 25, 456, "fluxaudio_fm"], | |
# ["Hammer slowly hitting the wooden table", "", 10.0, 3.5, 25, 42, "fluxaudio_fm"], | |
# ["Dog barking excitedly and man shouting as race car engine roars past", "", 10.0, 3, 1, 42, "meanaudio_mf"], | |
# ["A dog barking and a cat mewing and a racing car passes by", "", 12.0, 3, 5, -1, "meanaudio_mf"], | |
# ["Whistling with birds chirping", "", 10.0, 4, 50, 42, "fluxaudio_fm"], | |
# ] | |
# gr.Examples( | |
# examples=audio_examples, | |
# inputs=[prompt, negative_prompt, duration, cfg_strength, num_steps, seed, variant], | |
# #outputs=[generate_output_text, audio_output], | |
# #fn=generate_audio_gradio, | |
# examples_per_page=5, | |
# label="Example Prompts", | |
# ) | |
# if __name__ == "__main__": | |
# demo.launch() | |