update app.py
Browse files- MeanAudio.py +0 -147
- app.py +19 -5
- easyinfer.py +0 -3
MeanAudio.py
DELETED
@@ -1,147 +0,0 @@
|
|
1 |
-
import warnings
|
2 |
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
3 |
-
import logging
|
4 |
-
from pathlib import Path
|
5 |
-
import torch
|
6 |
-
import torchaudio
|
7 |
-
from meanaudio.eval_utils import (ModelConfig, all_model_cfg, generate_mf, generate_fm, setup_eval_logging)
|
8 |
-
from meanaudio.model.flow_matching import FlowMatching
|
9 |
-
from meanaudio.model.mean_flow import MeanFlow
|
10 |
-
from meanaudio.model.networks import MeanAudio, get_mean_audio
|
11 |
-
from meanaudio.model.utils.features_utils import FeaturesUtils
|
12 |
-
from huggingface_hub import snapshot_download
|
13 |
-
|
14 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
15 |
-
torch.backends.cudnn.allow_tf32 = True
|
16 |
-
log = logging.getLogger()
|
17 |
-
|
18 |
-
@torch.inference_mode()
def MeanAudioInference(
    prompt='',
    negative_prompt='',
    model_path='',
    encoder_name='t5_clap',
    variant='meanaudio_mf',
    duration=10,
    cfg_strength=4.5,
    num_steps=1,
    output='./output',
    seed=42,
    full_precision=False,
    use_rope=True,
    text_c_dim=512,
    use_meanflow=False
):
    """Generate one audio clip from a text prompt and save it as a .wav file.

    Args:
        prompt (str): Text description guiding the audio generation
            (e.g. "a dog is barking").
        negative_prompt (str): Text description of sounds to avoid.
        model_path (str): Path to the model weights file. If empty, defaults
            to ``./weights/{variant}.pth``.
        encoder_name (str): Text encoder name forwarded to ``FeaturesUtils``.
        variant (str): Model variant to load; must be a key of
            ``all_model_cfg``.
        duration (int): Desired clip duration in seconds; must be positive.
        cfg_strength (float): Classifier-free guidance strength. Forced to 0
            whenever the MeanFlow sampler is used.
        num_steps (int): Number of sampler steps; must be positive.
        output (str): Directory the .wav file is written to (created if
            missing; ``~`` is expanded).
        seed (int): Seed for the ``torch.Generator`` passed to the sampler.
        full_precision (bool): Use ``torch.float32`` when True, otherwise
            ``torch.bfloat16``.
        use_rope (bool): Forwarded to ``get_mean_audio``.
        text_c_dim (int): Forwarded to ``get_mean_audio``.
        use_meanflow (bool): Use the MeanFlow sampler instead of
            FlowMatching. Always treated as True for variant
            ``'meanaudio_mf'``.

    Returns:
        Path: Location of the saved .wav file inside ``output``.

    Raises:
        ValueError: If ``duration`` or ``num_steps`` is not positive, or
            ``variant`` is unknown.
        FileNotFoundError: If the weights file is missing and either a custom
            path was given or the automatic download did not produce it.
    """
    setup_eval_logging()

    # Validate arguments before touching the filesystem or loading anything.
    if duration <= 0 or num_steps <= 0:
        raise ValueError("Duration and number of steps must be positive.")
    if variant not in all_model_cfg:
        raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")

    output_dir = Path(output).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float32 if full_precision else torch.bfloat16

    default_model_path = Path(f'./weights/{variant}.pth')
    model_path = Path(model_path) if model_path else default_model_path

    if not model_path.exists():
        # Compare Path objects, not strings: str(Path('./weights/x.pth')) is
        # 'weights/x.pth', so a string comparison against './weights/...'
        # never matches and the download branch would be unreachable.
        if model_path != default_model_path:
            raise FileNotFoundError(f"Model file not found: {model_path}")

        log.info(f'Model not found at {model_path}')
        log.info('Downloading models to "./weights/"...')
        try:
            Path('./weights').mkdir(exist_ok=True)
            snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights", allow_patterns=["*.pt", "*.pth"])
        except Exception as e:
            log.error(f"Failed to download model: {e}")
            raise FileNotFoundError(f"Model file not found and download failed: {model_path}") from e
        # The snapshot may not contain weights for this particular variant.
        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found and download failed: {model_path}")

    model = all_model_cfg[variant]
    seq_cfg = model.seq_cfg
    # NOTE(review): this writes to the config object stored in all_model_cfg,
    # so the duration presumably persists across calls -- confirm that is intended.
    seq_cfg.duration = duration

    net = get_mean_audio(model.model_name, use_rope=use_rope, text_c_dim=text_c_dim)
    net = net.to(device, dtype).eval()
    net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
    net.update_seq_lengths(seq_cfg.latent_seq_len)

    # The 'meanaudio_mf' variant always uses the MeanFlow sampler.
    if variant == 'meanaudio_mf':
        use_meanflow = True

    if use_meanflow:
        sampler = MeanFlow(steps=num_steps)
        sampler_arg_name = 'mf'
        generate_fn = generate_mf
        cfg_strength = 0  # guidance strength is not used with MeanFlow
    else:
        sampler = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
        sampler_arg_name = 'fm'
        generate_fn = generate_fm

    feature_utils = FeaturesUtils(
        tod_vae_ckpt=model.vae_path,
        enable_conditions=True,
        encoder_name=encoder_name,
        mode=model.mode,
        bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
        need_vae_encoder=False
    )
    feature_utils = feature_utils.to(device, dtype).eval()

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    audios = generate_fn(
        [prompt],
        negative_text=[negative_prompt],
        feature_utils=feature_utils,
        net=net,
        rng=rng,
        cfg_strength=cfg_strength,
        **{sampler_arg_name: sampler},
    )
    audio = audios.float().cpu()[0]

    safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
    save_path = output_dir / f'{safe_filename}--numsteps{num_steps}--seed{seed}.wav'
    torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
    log.info(f'Audio saved to {save_path}')
    # Peak CUDA memory is only meaningful when the run actually used the GPU.
    if device == 'cuda':
        log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
    return save_path
|
145 |
-
|
146 |
-
if __name__ == '__main__':
|
147 |
-
MeanAudioInference('a dog is barking')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -8,6 +8,7 @@ import torch
|
|
8 |
import torchaudio
|
9 |
import gradio as gr
|
10 |
from transformers import AutoModel
|
|
|
11 |
from meanaudio.eval_utils import (
|
12 |
ModelConfig,
|
13 |
all_model_cfg,
|
@@ -31,12 +32,15 @@ if torch.cuda.is_available():
|
|
31 |
setup_eval_logging()
|
32 |
OUTPUT_DIR = Path("./output/gradio")
|
33 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
34 |
-
|
35 |
snapshot_download(repo_id="google/flan-t5-large")
|
36 |
a=AutoModel.from_pretrained('bert-base-uncased')
|
37 |
b=AutoModel.from_pretrained('roberta-base')
|
38 |
snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
|
39 |
-
|
|
|
|
|
|
|
40 |
current_model_states = {
|
41 |
|
42 |
}
|
@@ -190,16 +194,26 @@ def generate_audio_gradio(
|
|
190 |
generation_func = generate_fm
|
191 |
sampler_arg_name = "fm"
|
192 |
|
193 |
-
prompts = [prompt]
|
194 |
audios = generation_func(
|
195 |
-
|
196 |
-
negative_text=[negative_prompt],
|
197 |
feature_utils=feature_utils,
|
198 |
net=net,
|
199 |
rng=rng,
|
200 |
cfg_strength=cfg_strength,
|
201 |
**{sampler_arg_name: sampler},
|
202 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
audio = audios.float().cpu()[0]
|
204 |
safe_prompt = (
|
205 |
"".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
|
|
|
8 |
import torchaudio
|
9 |
import gradio as gr
|
10 |
from transformers import AutoModel
|
11 |
+
import laion_clap
|
12 |
from meanaudio.eval_utils import (
|
13 |
ModelConfig,
|
14 |
all_model_cfg,
|
|
|
32 |
setup_eval_logging()
|
33 |
OUTPUT_DIR = Path("./output/gradio")
|
34 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
35 |
+
NUM_SAMPLE=8
|
36 |
snapshot_download(repo_id="google/flan-t5-large")
|
37 |
a=AutoModel.from_pretrained('bert-base-uncased')
|
38 |
b=AutoModel.from_pretrained('roberta-base')
|
39 |
snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
|
40 |
+
_clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
|
41 |
+
laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False,
|
42 |
+
amodel='HTSAT-base').cuda().eval()
|
43 |
+
laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
|
44 |
current_model_states = {
|
45 |
|
46 |
}
|
|
|
194 |
generation_func = generate_fm
|
195 |
sampler_arg_name = "fm"
|
196 |
|
|
|
197 |
audios = generation_func(
|
198 |
+
[prompt]*NUM_SAMPLE,
|
199 |
+
negative_text=[negative_prompt]*NUM_SAMPLE,
|
200 |
feature_utils=feature_utils,
|
201 |
net=net,
|
202 |
rng=rng,
|
203 |
cfg_strength=cfg_strength,
|
204 |
**{sampler_arg_name: sampler},
|
205 |
)
|
206 |
+
for i in range(NUM_SAMPLE):
|
207 |
+
audio = audios.float().cpu()[i]
|
208 |
+
text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
|
209 |
+
audio_embed = laion_clap_model.get_audio_embedding_from_data(audio, use_tensor=True).squeeze()
|
210 |
+
score = torch.cosine_similarity(text_embed,
|
211 |
+
audio_embed,
|
212 |
+
dim=-1).mean()
|
213 |
+
all_audios.append(audio)
|
214 |
+
all_scores.append(score)
|
215 |
+
winner_idx = torch.argmax(torch.tensor(all_scores)).item()
|
216 |
+
audio=all_audios[winner_idx]
|
217 |
audio = audios.float().cpu()[0]
|
218 |
safe_prompt = (
|
219 |
"".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
|
easyinfer.py
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
from MeanAudio import MeanAudioInference
|
2 |
-
audio_path=MeanAudioInference('a dog is barking')
|
3 |
-
print(audio_path)
|
|
|
|
|
|
|
|