junxiliu committed on
Commit
fec53ba
·
1 Parent(s): 93f6a95

update app.py

Browse files
Files changed (3) hide show
  1. MeanAudio.py +0 -147
  2. app.py +19 -5
  3. easyinfer.py +0 -3
MeanAudio.py DELETED
@@ -1,147 +0,0 @@
1
- import warnings
2
- warnings.filterwarnings("ignore", category=FutureWarning)
3
- import logging
4
- from pathlib import Path
5
- import torch
6
- import torchaudio
7
- from meanaudio.eval_utils import (ModelConfig, all_model_cfg, generate_mf, generate_fm, setup_eval_logging)
8
- from meanaudio.model.flow_matching import FlowMatching
9
- from meanaudio.model.mean_flow import MeanFlow
10
- from meanaudio.model.networks import MeanAudio, get_mean_audio
11
- from meanaudio.model.utils.features_utils import FeaturesUtils
12
- from huggingface_hub import snapshot_download
13
-
14
- torch.backends.cuda.matmul.allow_tf32 = True
15
- torch.backends.cudnn.allow_tf32 = True
16
- log = logging.getLogger()
17
-
18
- @torch.inference_mode()
19
- def MeanAudioInference(
20
- prompt='',
21
- negative_prompt='',
22
- model_path='',
23
- encoder_name='t5_clap',
24
- variant='meanaudio_mf',
25
- duration=10,
26
- cfg_strength=4.5,
27
- num_steps=1,
28
- output='./output',
29
- seed=42,
30
- full_precision=False,
31
- use_rope=True,
32
- text_c_dim=512,
33
- use_meanflow=False
34
- ):
35
- '''
36
- prompt (str):
37
- The text description guiding the audio generation (e.g., "a dog is barking").
38
- negative_prompt (str):
39
- A text description for sounds that should be avoided in the generated audio.
40
- model_path (str):
41
- Path to the model weights file. If empty, it defaults to ./weights/{variant}.pth.
42
- encoder_name (str):
43
- Specifies the text encoder to use (default: 't5_clap').
44
- variant (str):
45
- Specifies the model variant to load (default: 'meanaudio_mf'). Must be a key in all_model_cfg.
46
- duration (int):
47
- The desired duration of the generated audio in seconds (default: 10).
48
- cfg_strength (float):
49
- Classifier-Free Guidance strength. Ignored if use_meanflow is True or variant is 'meanaudio_mf' (default: 4.5).
50
- num_steps (int):
51
- Number of steps for the generation process (default: 1).
52
- output (str):
53
- Directory path where the generated audio file will be saved (default: './output').
54
- seed (int):
55
- Random seed for generation reproducibility (default: 42).
56
- full_precision (bool):
57
- If True, uses torch.float32 precision; otherwise, uses torch.bfloat16 (default: False).
58
- use_rope (bool):
59
- Whether to use Rotary Position Embedding in the model (default: True).
60
- text_c_dim (int):
61
- Dimension of the text context vector (default: 512).
62
- use_meanflow (bool):
63
- If True, uses the MeanFlow generation method; otherwise, uses FlowMatching. If variant is 'meanaudio_mf', this is automatically set to True (default: False).
64
- '''
65
- setup_eval_logging()
66
- output_dir = Path(output).expanduser()
67
- output_dir.mkdir(parents=True, exist_ok=True)
68
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
- dtype = torch.float32 if full_precision else torch.bfloat16
70
- if duration <= 0 or num_steps <= 0:
71
- raise ValueError("Duration and number of steps must be positive.")
72
- if variant not in all_model_cfg:
73
- raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
74
- if not model_path or model_path == '':
75
- model_path = Path(f'./weights/{variant}.pth')
76
- else:
77
- model_path = Path(model_path)
78
- if not model_path.exists():
79
- if str(model_path) == f'./weights/{variant}.pth':
80
- log.info(f'Model not found at {model_path}')
81
- log.info('Downloading models to "./weights/"...')
82
- try:
83
- weights_dir = Path('./weights')
84
- weights_dir.mkdir(exist_ok=True)
85
- snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
86
- raise NotImplementedError("Model download functionality needs to be implemented")
87
- except Exception as e:
88
- log.error(f"Failed to download model: {e}")
89
- raise FileNotFoundError(f"Model file not found and download failed: {model_path}")
90
- else:
91
- raise FileNotFoundError(f"Model file not found: {model_path}")
92
-
93
- model = all_model_cfg[variant]
94
- seq_cfg = model.seq_cfg
95
- seq_cfg.duration = duration
96
-
97
- net = get_mean_audio(model.model_name, use_rope=use_rope, text_c_dim=text_c_dim)
98
- net = net.to(device, dtype).eval()
99
- net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
100
- net.update_seq_lengths(seq_cfg.latent_seq_len)
101
-
102
- if variant=='meanaudio_mf':
103
- use_meanflow=True
104
- if use_meanflow:
105
- generation_func = MeanFlow(steps=num_steps)
106
- cfg_strength=0
107
- else:
108
- generation_func = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
109
-
110
- feature_utils = FeaturesUtils(
111
- tod_vae_ckpt=model.vae_path,
112
- enable_conditions=True,
113
- encoder_name=encoder_name,
114
- mode=model.mode,
115
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
116
- need_vae_encoder=False
117
- )
118
- feature_utils = feature_utils.to(device, dtype).eval()
119
-
120
- rng = torch.Generator(device=device)
121
- rng.manual_seed(seed)
122
-
123
- generate_fn = generate_mf if use_meanflow else generate_fm
124
- kwargs = {
125
- 'negative_text': [negative_prompt],
126
- 'feature_utils': feature_utils,
127
- 'net': net,
128
- 'rng': rng,
129
- 'cfg_strength': cfg_strength
130
- }
131
-
132
- if use_meanflow:
133
- kwargs['mf'] = generation_func
134
- else:
135
- kwargs['fm'] = generation_func
136
-
137
- audios = generate_fn([prompt], **kwargs)
138
- audio = audios.float().cpu()[0]
139
- safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
140
- save_path = output_dir / f'{safe_filename}--numsteps{num_steps}--seed{seed}.wav'
141
- torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
142
- log.info(f'Audio saved to {save_path}')
143
- log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
144
- return save_path
145
-
146
- if __name__ == '__main__':
147
- MeanAudioInference('a dog is barking')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -8,6 +8,7 @@ import torch
8
  import torchaudio
9
  import gradio as gr
10
  from transformers import AutoModel
 
11
  from meanaudio.eval_utils import (
12
  ModelConfig,
13
  all_model_cfg,
@@ -31,12 +32,15 @@ if torch.cuda.is_available():
31
  setup_eval_logging()
32
  OUTPUT_DIR = Path("./output/gradio")
33
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
34
-
35
  snapshot_download(repo_id="google/flan-t5-large")
36
  a=AutoModel.from_pretrained('bert-base-uncased')
37
  b=AutoModel.from_pretrained('roberta-base')
38
  snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
39
-
 
 
 
40
  current_model_states = {
41
 
42
  }
@@ -190,16 +194,26 @@ def generate_audio_gradio(
190
  generation_func = generate_fm
191
  sampler_arg_name = "fm"
192
 
193
- prompts = [prompt]
194
  audios = generation_func(
195
- prompts,
196
- negative_text=[negative_prompt],
197
  feature_utils=feature_utils,
198
  net=net,
199
  rng=rng,
200
  cfg_strength=cfg_strength,
201
  **{sampler_arg_name: sampler},
202
  )
 
 
 
 
 
 
 
 
 
 
 
203
  audio = audios.float().cpu()[0]
204
  safe_prompt = (
205
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
 
8
  import torchaudio
9
  import gradio as gr
10
  from transformers import AutoModel
11
+ import laion_clap
12
  from meanaudio.eval_utils import (
13
  ModelConfig,
14
  all_model_cfg,
 
32
  setup_eval_logging()
33
  OUTPUT_DIR = Path("./output/gradio")
34
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
35
+ NUM_SAMPLE=8
36
  snapshot_download(repo_id="google/flan-t5-large")
37
  a=AutoModel.from_pretrained('bert-base-uncased')
38
  b=AutoModel.from_pretrained('roberta-base')
39
  snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
40
+ _clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
41
+ laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False,
42
+ amodel='HTSAT-base').cuda().eval()
43
+ laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
44
  current_model_states = {
45
 
46
  }
 
194
  generation_func = generate_fm
195
  sampler_arg_name = "fm"
196
 
 
197
  audios = generation_func(
198
+ [prompt]*NUM_SAMPLE,
199
+ negative_text=[negative_prompt]*NUM_SAMPLE,
200
  feature_utils=feature_utils,
201
  net=net,
202
  rng=rng,
203
  cfg_strength=cfg_strength,
204
  **{sampler_arg_name: sampler},
205
  )
206
+ for i in range(NUM_SAMPLE):
207
+ audio = audios.float().cpu()[i]
208
+ text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
209
+ audio_embed = laion_clap_model.get_audio_embedding_from_data(audio, use_tensor=True).squeeze()
210
+ score = torch.cosine_similarity(text_embed,
211
+ audio_embed,
212
+ dim=-1).mean()
213
+ all_audios.append(audio)
214
+ all_scores.append(score)
215
+ winner_idx = torch.argmax(torch.tensor(all_scores)).item()
216
+ audio=all_audios[winner_idx]
217
  audio = audios.float().cpu()[0]
218
  safe_prompt = (
219
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
easyinfer.py DELETED
@@ -1,3 +0,0 @@
1
- from MeanAudio import MeanAudioInference
2
- audio_path=MeanAudioInference('a dog is barking')
3
- print(audio_path)