File size: 5,957 Bytes
0ad0454 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# Silence noisy deprecation chatter from upstream libs before they are imported.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import logging
from pathlib import Path
import torch
import torchaudio
from meanaudio.eval_utils import (ModelConfig, all_model_cfg, generate_mf, generate_fm, setup_eval_logging)
from meanaudio.model.flow_matching import FlowMatching
from meanaudio.model.mean_flow import MeanFlow
from meanaudio.model.networks import MeanAudio, get_mean_audio
from meanaudio.model.utils.features_utils import FeaturesUtils
from huggingface_hub import snapshot_download
# Allow TF32 on Ampere+ GPUs: faster matmuls/convs at slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Root logger; handlers/level are configured later by setup_eval_logging().
log = logging.getLogger()
@torch.inference_mode()
def MeanAudioInference(
    prompt='',
    negative_prompt='',
    model_path='',
    encoder_name='t5_clap',
    variant='meanaudio_mf',
    duration=10,
    cfg_strength=4.5,
    num_steps=1,
    output='./output',
    seed=42,
    full_precision=False,
    use_rope=True,
    text_c_dim=512,
    use_meanflow=False
):
    '''Generate an audio clip from a text prompt and save it as a WAV file.

    prompt (str):
        The text description guiding the audio generation (e.g., "a dog is barking").
    negative_prompt (str):
        A text description for sounds that should be avoided in the generated audio.
    model_path (str):
        Path to the model weights file. If empty, it defaults to ./weights/{variant}.pth
        and missing weights are auto-downloaded from the Hugging Face Hub.
    encoder_name (str):
        Specifies the text encoder to use (default: 't5_clap').
    variant (str):
        Specifies the model variant to load (default: 'meanaudio_mf'). Must be a key in all_model_cfg.
    duration (int):
        The desired duration of the generated audio in seconds (default: 10).
    cfg_strength (float):
        Classifier-Free Guidance strength. Ignored if use_meanflow is True or variant is 'meanaudio_mf' (default: 4.5).
    num_steps (int):
        Number of steps for the generation process (default: 1).
    output (str):
        Directory path where the generated audio file will be saved (default: './output').
    seed (int):
        Random seed for generation reproducibility (default: 42).
    full_precision (bool):
        If True, uses torch.float32 precision; otherwise, uses torch.bfloat16 (default: False).
    use_rope (bool):
        Whether to use Rotary Position Embedding in the model (default: True).
    text_c_dim (int):
        Dimension of the text context vector (default: 512).
    use_meanflow (bool):
        If True, uses the MeanFlow generation method; otherwise, uses FlowMatching.
        If variant is 'meanaudio_mf', this is automatically set to True (default: False).

    Returns:
        pathlib.Path: Path of the saved .wav file.

    Raises:
        ValueError: If duration/num_steps are non-positive or the variant is unknown.
        FileNotFoundError: If the weights are missing and cannot be downloaded.
    '''
    setup_eval_logging()
    output_dir = Path(output).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float32 if full_precision else torch.bfloat16
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
    # Empty model_path means "use the conventional per-variant location".
    default_path = f'./weights/{variant}.pth'
    model_path = Path(model_path) if model_path else Path(default_path)
    if not model_path.exists():
        if str(model_path) == default_path:
            # Auto-download only when the caller relied on the default location;
            # an explicit-but-missing path is treated as a hard error below.
            log.info(f'Model not found at {model_path}')
            log.info('Downloading models to "./weights/"...')
            try:
                weights_dir = Path('./weights')
                weights_dir.mkdir(exist_ok=True)
                snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights",
                                  allow_patterns=["*.pt", "*.pth"])
            except Exception as e:
                log.error(f"Failed to download model: {e}")
                raise FileNotFoundError(f"Model file not found and download failed: {model_path}")
            # BUGFIX: previously an unconditional NotImplementedError was raised here,
            # so the download path could never succeed. Now we simply verify the file.
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found after download: {model_path}")
        else:
            raise FileNotFoundError(f"Model file not found: {model_path}")
    model = all_model_cfg[variant]
    seq_cfg = model.seq_cfg
    seq_cfg.duration = duration
    net = get_mean_audio(model.model_name, use_rope=use_rope, text_c_dim=text_c_dim)
    net = net.to(device, dtype).eval()
    net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
    net.update_seq_lengths(seq_cfg.latent_seq_len)
    # The meanaudio_mf variant is trained for MeanFlow sampling; force the matching sampler.
    if variant == 'meanaudio_mf':
        use_meanflow = True
    if use_meanflow:
        generation_func = MeanFlow(steps=num_steps)
        cfg_strength = 0  # CFG is not applied under MeanFlow sampling.
    else:
        generation_func = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        enable_conditions=True,
        encoder_name=encoder_name,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False  # inference only decodes latents; the encoder is unused
    )
    feature_utils = feature_utils.to(device, dtype).eval()
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    generate_fn = generate_mf if use_meanflow else generate_fm
    kwargs = {
        'negative_text': [negative_prompt],
        'feature_utils': feature_utils,
        'net': net,
        'rng': rng,
        'cfg_strength': cfg_strength
    }
    # The two generator entry points take the sampler under different keyword names.
    if use_meanflow:
        kwargs['mf'] = generation_func
    else:
        kwargs['fm'] = generation_func
    audios = generate_fn([prompt], **kwargs)
    audio = audios.float().cpu()[0]
    # Build a filesystem-safe name from the prompt; fall back if it sanitizes to ''.
    safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '') or 'audio'
    save_path = output_dir / f'{safe_filename}--numsteps{num_steps}--seed{seed}.wav'
    torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
    log.info(f'Audio saved to {save_path}')
    if torch.cuda.is_available():
        # max_memory_allocated is only meaningful (and safe) with a CUDA device.
        log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
    return save_path
if __name__ == '__main__':
    # Smoke-test entry point: generate one clip for a fixed example prompt
    # using all default settings (weights auto-resolved under ./weights/).
    MeanAudioInference('a dog is barking')