AndreasXi commited on
Commit
0ff9928
·
1 Parent(s): 73214d1

update new model versions and test

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
__pycache__/MeanAudio.cpython-311.pyc DELETED
Binary file (8.39 kB)
 
app.py CHANGED
@@ -16,6 +16,7 @@ from meanaudio.eval_utils import (
16
  generate_fm,
17
  setup_eval_logging,
18
  )
 
19
  from meanaudio.model.flow_matching import FlowMatching
20
  from meanaudio.model.mean_flow import MeanFlow
21
  from meanaudio.model.networks import MeanAudio, get_mean_audio
@@ -25,117 +26,28 @@ torch.backends.cudnn.allow_tf32 = True
25
  import gc
26
  from datetime import datetime
27
  from huggingface_hub import snapshot_download
 
28
  log = logging.getLogger()
29
  device = "cpu"
 
30
  if torch.cuda.is_available():
31
  device = "cuda"
32
  setup_eval_logging()
 
33
  OUTPUT_DIR = Path("./output/gradio")
34
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
35
  NUM_SAMPLE=7
36
- snapshot_download(repo_id="google/flan-t5-large")
37
- a=AutoModel.from_pretrained('bert-base-uncased')
38
- b=AutoModel.from_pretrained('roberta-base')
39
- snapshot_download(repo_id="junxiliu/Meanaudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
40
- _clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
41
- laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False,
42
- amodel='HTSAT-base').cuda().eval()
43
- laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
44
- current_model_states = {
45
-
46
- }
47
-
48
- def load_model_if_needed(
49
- variant, model_path, encoder_name, use_rope, text_c_dim
50
- ):
51
- global current_model_states
52
- dtype = torch.float32
53
- existing_state = current_model_states.get(variant)
54
- needs_reload = (
55
- existing_state is None
56
- or existing_state["args"].variant != variant
57
- or existing_state["args"].model_path != model_path
58
- or existing_state["args"].encoder_name != encoder_name
59
- or existing_state["args"].use_rope != use_rope
60
- or existing_state["args"].text_c_dim != text_c_dim
61
- )
62
- if needs_reload:
63
- log.info(f"Loading/reloading model '{variant}'.")
64
- if variant not in all_model_cfg:
65
- raise ValueError(f"Unknown model variant: {variant}")
66
- model: ModelConfig = all_model_cfg[variant]
67
- seq_cfg = model.seq_cfg
68
-
69
- class MockArgs:
70
- pass
71
- mock_args = MockArgs()
72
- mock_args.variant = variant
73
- mock_args.model_path = model_path
74
- mock_args.encoder_name = encoder_name
75
- mock_args.use_rope = use_rope
76
- mock_args.text_c_dim = text_c_dim
77
-
78
- net: MeanAudio = (
79
- get_mean_audio(
80
- model.model_name,
81
- use_rope=mock_args.use_rope,
82
- text_c_dim=mock_args.text_c_dim,
83
- )
84
- .to(device, dtype)
85
- .eval()
86
- )
87
- net.load_weights(
88
- torch.load(
89
- mock_args.model_path, map_location=device, weights_only=True
90
- )
91
- )
92
- log.info(f"Loaded weights from {mock_args.model_path}")
93
-
94
- feature_utils = FeaturesUtils(
95
- tod_vae_ckpt=model.vae_path,
96
- enable_conditions=True,
97
- encoder_name=mock_args.encoder_name,
98
- mode=model.mode,
99
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
100
- need_vae_encoder=False,
101
- )
102
- feature_utils = feature_utils.to(device, dtype).eval()
103
-
104
- current_model_states[variant] = {
105
- "net": net,
106
- "feature_utils": feature_utils,
107
- "seq_cfg": seq_cfg,
108
- "args": mock_args,
109
- }
110
- log.info(f"Model '{variant}' loaded successfully.")
111
-
112
- return net, feature_utils, seq_cfg, mock_args
113
- else:
114
- log.info(f"Model '{variant}' already loaded with current settings. Skipping reload.")
115
-
116
- return existing_state["net"], existing_state["feature_utils"], existing_state["seq_cfg"], existing_state["args"]
117
 
118
- def initialize_all_default_models():
119
- log.info("Initializing default models...")
120
- default_models = ['meanaudio_mf', 'fluxaudio_fm']
121
- common_params = {
122
- "encoder_name": "t5_clap",
123
- "use_rope": True,
124
- "text_c_dim": 512,
125
 
126
- }
127
- for variant in default_models:
128
- model_path = f"./weights/{variant}.pth"
129
 
130
- try:
131
- load_model_if_needed(
132
- variant, model_path, **common_params
133
- )
134
- log.info(f"Default model '{variant}' initialized successfully.")
135
- except Exception as e:
136
- log.error(f"Failed to initialize default model '{variant}': {e}")
137
 
138
- initialize_all_default_models()
139
 
140
  @spaces.GPU(duration=10)
141
  @torch.inference_mode()
@@ -148,44 +60,42 @@ def generate_audio_gradio(
148
  seed,
149
  variant,
150
  ):
151
- global current_model_states
152
-
153
- model_path = f"./weights/{variant}.pth"
154
- encoder_name = "t5_clap"
155
- use_rope = True
156
- text_c_dim = 512
157
-
158
- model_state = current_model_states.get(variant)
159
- if model_state is None:
160
- error_msg = f"Error: Model '{variant}' is not available. It may not have been loaded correctly during startup."
161
- log.error(error_msg)
162
- return error_msg, None
163
-
164
- net = model_state["net"]
165
- feature_utils = model_state["feature_utils"]
166
- seq_cfg = model_state["seq_cfg"]
167
-
168
- args = model_state["args"]
169
  dtype = torch.float32
 
 
 
 
170
 
171
- temp_seq_cfg = type(seq_cfg)(**seq_cfg.__dict__)
172
- temp_seq_cfg.duration = duration
 
 
 
 
 
 
 
173
 
174
- net.update_seq_lengths(temp_seq_cfg.latent_seq_len)
 
 
 
 
 
 
175
 
176
- rng = torch.Generator(device=device)
177
- if seed >= 0:
178
- rng.manual_seed(seed)
179
- else:
180
- rng.seed()
181
 
182
- use_meanflow = variant == "meanaudio_mf"
 
 
 
 
183
  if use_meanflow:
184
  sampler = MeanFlow(steps=num_steps)
185
  log.info("Using MeanFlow for generation.")
186
  generation_func = generate_mf
187
  sampler_arg_name = "mf"
188
- cfg_strength = 3
189
  else:
190
  sampler = FlowMatching(
191
  min_sigma=0, inference_mode="euler", num_steps=num_steps
@@ -193,6 +103,10 @@ def generate_audio_gradio(
193
  log.info("Using FlowMatching for generation.")
194
  generation_func = generate_fm
195
  sampler_arg_name = "fm"
 
 
 
 
196
  audios = generation_func(
197
  [prompt]*NUM_SAMPLE,
198
  negative_text=[negative_prompt]*NUM_SAMPLE,
@@ -205,11 +119,11 @@ def generate_audio_gradio(
205
  text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
206
  audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
207
  scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
208
- audio_embed,
209
- dim=-1)
210
  log.info(scores)
211
  log.info(torch.argmax(scores).item())
212
- audio=audios[torch.argmax(scores).item()].float().cpu()
213
  safe_prompt = (
214
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
215
  .rstrip()
@@ -400,7 +314,6 @@ with gr.Blocks(title="MeanAudio Generator", theme=theme, css=custom_css) as demo
400
  interactive=True,
401
  scale=3,
402
  )
403
-
404
  with gr.Column(elem_classes="setting-section"):
405
  with gr.Row():
406
  prompt = gr.Textbox(
 
16
  generate_fm,
17
  setup_eval_logging,
18
  )
19
+
20
  from meanaudio.model.flow_matching import FlowMatching
21
  from meanaudio.model.mean_flow import MeanFlow
22
  from meanaudio.model.networks import MeanAudio, get_mean_audio
 
26
  import gc
27
  from datetime import datetime
28
  from huggingface_hub import snapshot_download
29
+
30
  log = logging.getLogger()
31
  device = "cpu"
32
+
33
  if torch.cuda.is_available():
34
  device = "cuda"
35
  setup_eval_logging()
36
+
37
  OUTPUT_DIR = Path("./output/gradio")
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
  NUM_SAMPLE=7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ snapshot_download(repo_id="google/flan-t5-large")
42
+ a = AutoModel.from_pretrained('bert-base-uncased')
43
+ b = AutoModel.from_pretrained('roberta-base')
 
 
 
 
44
 
45
+ snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
46
+ _clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
47
+ laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').cuda().eval()
48
 
49
+ laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
 
 
 
 
 
 
50
 
 
51
 
52
  @spaces.GPU(duration=10)
53
  @torch.inference_mode()
 
60
  seed,
61
  variant,
62
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  dtype = torch.float32
64
+ if duration <= 0 or num_steps <= 0:
65
+ raise ValueError("Duration and number of steps must be positive.")
66
+ if variant not in all_model_cfg:
67
+ raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
68
 
69
+ model_path = all_model_cfg[variant].model_path # by default, this will use meanaudio_s_full.pth or fluxaudio_s_full.pth
70
+ model = all_model_cfg[variant]
71
+ seq_cfg = model.seq_cfg
72
+ seq_cfg.duration = duration
73
+
74
+ net = get_mean_audio(model.model_name, use_rope=True, text_c_dim=512)
75
+ net = net.to(device, dtype).eval()
76
+ net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
77
+ net.update_seq_lengths(seq_cfg.latent_seq_len)
78
 
79
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
80
+ enable_conditions=True,
81
+ encoder_name="t5_clap",
82
+ mode=model.mode,
83
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
84
+ need_vae_encoder=False)
85
+ feature_utils = feature_utils.to(device, dtype).eval()
86
 
 
 
 
 
 
87
 
88
+ if variant == 'meanaudio_s_ac' or variant == 'meanaudio_s_full':
89
+ use_meanflow=True
90
+ elif variant == 'fluxaudio_s_full':
91
+ use_meanflow=False
92
+
93
  if use_meanflow:
94
  sampler = MeanFlow(steps=num_steps)
95
  log.info("Using MeanFlow for generation.")
96
  generation_func = generate_mf
97
  sampler_arg_name = "mf"
98
+ cfg_strength = 0
99
  else:
100
  sampler = FlowMatching(
101
  min_sigma=0, inference_mode="euler", num_steps=num_steps
 
103
  log.info("Using FlowMatching for generation.")
104
  generation_func = generate_fm
105
  sampler_arg_name = "fm"
106
+
107
+ rng = torch.Generator(device=device)
108
+ # rng.manual_seed(seed)
109
+
110
  audios = generation_func(
111
  [prompt]*NUM_SAMPLE,
112
  negative_text=[negative_prompt]*NUM_SAMPLE,
 
119
  text_embed = laion_clap_model.get_text_embedding(prompt, use_tensor=True).squeeze()
120
  audio_embed = laion_clap_model.get_audio_embedding_from_data(audios[:,0,:].float().cpu(), use_tensor=True).squeeze()
121
  scores = torch.cosine_similarity(text_embed.expand(audio_embed.shape[0], -1),
122
+ audio_embed,
123
+ dim=-1)
124
  log.info(scores)
125
  log.info(torch.argmax(scores).item())
126
+ audio = audios[torch.argmax(scores).item()].float().cpu()
127
  safe_prompt = (
128
  "".join(c for c in prompt if c.isalnum() or c in (" ", "_"))
129
  .rstrip()
 
314
  interactive=True,
315
  scale=3,
316
  )
 
317
  with gr.Column(elem_classes="setting-section"):
318
  with gr.Row():
319
  prompt = gr.Textbox(
meanaudio/eval_utils.py CHANGED
@@ -43,20 +43,26 @@ class ModelConfig:
43
  download_model_if_needed(self.bigvgan_16k_path)
44
 
45
 
46
- fluxaudio_fm = ModelConfig(model_name='fluxaudio_fm',
47
- model_path=Path('./weights/fluxaudio_fm.pth'),
48
  vae_path=Path('./weights/v1-16.pth'),
49
  bigvgan_16k_path=Path('./weights/best_netG.pt'),
50
  mode='16k')
51
- meanaudio_mf = ModelConfig(model_name='meanaudio_mf',
52
- model_path=Path('./weights/meanaudio_mf.pth'),
 
 
 
 
 
53
  vae_path=Path('./weights/v1-16.pth'),
54
  bigvgan_16k_path=Path('./weights/best_netG.pt'),
55
  mode='16k')
56
 
57
  all_model_cfg: dict[str, ModelConfig] = {
58
- 'fluxaudio_fm': fluxaudio_fm,
59
- 'meanaudio_mf': meanaudio_mf,
 
60
  }
61
 
62
 
 
43
  download_model_if_needed(self.bigvgan_16k_path)
44
 
45
 
46
+ fluxaudio_s_full = ModelConfig(model_name='fluxaudio_s_full',
47
+ model_path=Path('./weights/fluxaudio_s_full.pth'), # will be specified later
48
  vae_path=Path('./weights/v1-16.pth'),
49
  bigvgan_16k_path=Path('./weights/best_netG.pt'),
50
  mode='16k')
51
+ meanaudio_s_full = ModelConfig(model_name='meanaudio_s_full',
52
+ model_path=Path('./weights/meanaudio_s_full.pth'), # will be specified later
53
+ vae_path=Path('./weights/v1-16.pth'),
54
+ bigvgan_16k_path=Path('./weights/best_netG.pt'),
55
+ mode='16k')
56
+ meanaudio_s_ac = ModelConfig(model_name='meanaudio_s_ac',
57
+ model_path=Path('./weights/meanaudio_s_ac.pth'), # will be specified later
58
  vae_path=Path('./weights/v1-16.pth'),
59
  bigvgan_16k_path=Path('./weights/best_netG.pt'),
60
  mode='16k')
61
 
62
  all_model_cfg: dict[str, ModelConfig] = {
63
+ 'fluxaudio_s_full': fluxaudio_s_full,
64
+ 'meanaudio_s_full': meanaudio_s_full,
65
+ 'meanaudio_s_ac': meanaudio_s_ac,
66
  }
67
 
68
 
meanaudio/model/networks.py CHANGED
@@ -577,7 +577,7 @@ class MeanAudio(nn.Module):
577
  return self._latent_seq_len
578
 
579
 
580
- def fluxaudio_fm(**kwargs) -> FluxAudio:
581
  num_heads = 7
582
  return FluxAudio(latent_dim=20,
583
  text_dim=1024,
@@ -587,7 +587,7 @@ def fluxaudio_fm(**kwargs) -> FluxAudio:
587
  num_heads=num_heads,
588
  latent_seq_len=312, # for 10s audio
589
  **kwargs)
590
- def meanaudio_mf(**kwargs) -> MeanAudio:
591
  num_heads = 7
592
  return MeanAudio(latent_dim=20,
593
  text_dim=1024,
@@ -600,10 +600,10 @@ def meanaudio_mf(**kwargs) -> MeanAudio:
600
 
601
 
602
  def get_mean_audio(name: str, **kwargs) -> MeanAudio:
603
- if name == 'meanaudio_mf':
604
- return meanaudio_mf(**kwargs)
605
- if name == 'fluxaudio_fm':
606
- return fluxaudio_fm(**kwargs)
607
 
608
  raise ValueError(f'Unknown model name: {name}')
609
 
@@ -620,7 +620,7 @@ if __name__ == '__main__':
620
  ]
621
  )
622
 
623
- network: MeanAudio = get_mean_audio('meanaudio_mf',
624
  use_rope=False,
625
  text_c_dim=512)
626
 
 
577
  return self._latent_seq_len
578
 
579
 
580
+ def fluxaudio_s(**kwargs) -> FluxAudio:
581
  num_heads = 7
582
  return FluxAudio(latent_dim=20,
583
  text_dim=1024,
 
587
  num_heads=num_heads,
588
  latent_seq_len=312, # for 10s audio
589
  **kwargs)
590
+ def meanaudio_s(**kwargs) -> MeanAudio:
591
  num_heads = 7
592
  return MeanAudio(latent_dim=20,
593
  text_dim=1024,
 
600
 
601
 
602
  def get_mean_audio(name: str, **kwargs) -> MeanAudio:
603
+ if name == 'meanaudio_s':
604
+ return meanaudio_s(**kwargs)
605
+ if name == 'fluxaudio_s':
606
+ return fluxaudio_s(**kwargs)
607
 
608
  raise ValueError(f'Unknown model name: {name}')
609
 
 
620
  ]
621
  )
622
 
623
+ network: MeanAudio = get_mean_audio('meanaudio_s',
624
  use_rope=False,
625
  text_c_dim=512)
626