AndreasXi committed on
Commit
f47e09f
·
1 Parent(s): fbfb8b6
Files changed (1) hide show
  1. app.py +42 -26
app.py CHANGED
@@ -38,18 +38,48 @@ OUTPUT_DIR = Path("./output/gradio")
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
  NUM_SAMPLE = 1
40
 
41
- # snapshot_download(repo_id="google/flan-t5-large")
42
- # a = AutoModel.from_pretrained('bert-base-uncased')
43
- # b = AutoModel.from_pretrained('roberta-base')
44
 
45
- # snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
46
- # _clap_ckpt_path='./weights/music_speech_audioset_epoch_15_esc_89.98.pt'
47
- # laion_clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base').cuda().eval()
 
 
 
48
 
49
- # laion_clap_model.load_ckpt(_clap_ckpt_path, verbose=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
- @spaces.GPU(duration=10)
53
  @torch.inference_mode()
54
  def generate_audio_gradio(
55
  prompt,
@@ -66,29 +96,14 @@ def generate_audio_gradio(
66
  if variant not in all_model_cfg:
67
  raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
68
 
69
- model_path = all_model_cfg[variant].model_path # by default, this will use meanaudio_s_full.pth or fluxaudio_s_full.pth
70
- if not model_path.exists():
71
- log.info(f'Model not found at {model_path}')
72
- log.info('Downloading models to "./weights/"...')
73
- snapshot_download(repo_id="AndreasXi/MeanAudio", local_dir="./weights",allow_patterns=["*.pt", "*.pth"] )
74
-
75
  model = all_model_cfg[variant]
76
  seq_cfg = model.seq_cfg
77
  seq_cfg.duration = duration
78
 
79
- net = get_mean_audio(model.model_name, use_rope=True, text_c_dim=512)
80
- net = net.to(device, dtype).eval()
81
- net.load_weights(torch.load(model_path, map_location=device, weights_only=True))
82
  net.update_seq_lengths(seq_cfg.latent_seq_len)
83
 
84
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
85
- enable_conditions=True,
86
- encoder_name="t5_clap",
87
- mode=model.mode,
88
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
89
- need_vae_encoder=False)
90
- feature_utils = feature_utils.to(device, dtype).eval()
91
-
92
 
93
  if variant == 'meanaudio_s_ac' or variant == 'meanaudio_s_full':
94
  use_meanflow=True
@@ -141,7 +156,8 @@ def generate_audio_gradio(
141
  torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
142
  log.info(f"Audio saved to {save_path}")
143
 
144
- gc.collect()
 
145
 
146
  return (
147
  f"Generated audio for prompt: '{prompt}' using {'MeanFlow' if use_meanflow else 'FlowMatching'}",
 
38
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
39
  NUM_SAMPLE = 1
40
 
41
+ # Global caches keyed by model variant, so each network and its feature utilities are loaded only once
42
+ MODEL_CACHE = {}
43
+ FEATURE_UTILS_CACHE = {}
44
 
45
def ensure_models_downloaded():
    """Fetch the checkpoint files if any configured model is missing on disk.

    Walks every entry of ``all_model_cfg`` and checks whether its
    ``model_path`` exists. On the first missing checkpoint, a single
    ``snapshot_download`` call pulls the ``*.pt`` / ``*.pth`` files of the
    ``AndreasXi/MeanAudio`` repository into ``./weights`` and the scan stops;
    no further variants are checked after that download.
    """
    for variant in all_model_cfg:
        if all_model_cfg[variant].model_path.exists():
            continue

        log.info(f'Model {variant} not found, downloading...')
        snapshot_download(
            repo_id="AndreasXi/MeanAudio",
            local_dir="./weights",
            allow_patterns=["*.pt", "*.pth"],
        )
        # One download is issued at most; stop scanning afterwards.
        return
51
 
52
def load_model_if_needed(variant: str):
    """Return the network and feature utilities for ``variant``, loading on first use.

    On the first request for a variant, the network is built with
    ``get_mean_audio``, moved to ``device`` in ``torch.bfloat16``, switched to
    eval mode and populated from the checkpoint at the variant's
    ``model_path``. A matching ``FeaturesUtils`` instance is created the same
    way. Both objects are stored in ``MODEL_CACHE`` / ``FEATURE_UTILS_CACHE``
    so later calls with the same variant return the cached objects.

    Args:
        variant: Key into ``all_model_cfg`` selecting the model configuration.

    Returns:
        A ``(net, feature_utils)`` tuple for the requested variant.
    """
    if variant not in MODEL_CACHE:
        log.info(f"Loading model {variant} for the first time...")
        cfg = all_model_cfg[variant]

        # Build the network first, place it on the device, then restore weights.
        network = get_mean_audio(cfg.model_name, use_rope=True, text_c_dim=512)
        network = network.to(device, torch.bfloat16).eval()
        state = torch.load(cfg.model_path, map_location=device, weights_only=True)
        network.load_weights(state)

        utils_kwargs = {
            "tod_vae_ckpt": cfg.vae_path,
            "enable_conditions": True,
            "encoder_name": "t5_clap",
            "mode": cfg.mode,
            "bigvgan_vocoder_ckpt": cfg.bigvgan_16k_path,
            "need_vae_encoder": False,
        }
        features = FeaturesUtils(**utils_kwargs)
        features = features.to(device, torch.bfloat16).eval()

        # Publish both entries only after everything loaded successfully,
        # so the two caches always hold the same set of variants.
        MODEL_CACHE[variant] = network
        FEATURE_UTILS_CACHE[variant] = features

        log.info(f"Model {variant} loaded and cached successfully")

    return MODEL_CACHE[variant], FEATURE_UTILS_CACHE[variant]
78
+
79
+ ensure_models_downloaded()
80
 
81
 
82
+ @spaces.GPU(duration=60)
83
  @torch.inference_mode()
84
  def generate_audio_gradio(
85
  prompt,
 
96
  if variant not in all_model_cfg:
97
  raise ValueError(f"Unknown model variant: {variant}. Available: {list(all_model_cfg.keys())}")
98
 
99
+ net, feature_utils = load_model_if_needed(variant)
100
+
 
 
 
 
101
  model = all_model_cfg[variant]
102
  seq_cfg = model.seq_cfg
103
  seq_cfg.duration = duration
104
 
 
 
 
105
  net.update_seq_lengths(seq_cfg.latent_seq_len)
106
 
 
 
 
 
 
 
 
 
107
 
108
  if variant == 'meanaudio_s_ac' or variant == 'meanaudio_s_full':
109
  use_meanflow=True
 
156
  torchaudio.save(str(save_path), audio, seq_cfg.sampling_rate)
157
  log.info(f"Audio saved to {save_path}")
158
 
159
+ if device == "cuda":
160
+ torch.cuda.empty_cache()
161
 
162
  return (
163
  f"Generated audio for prompt: '{prompt}' using {'MeanFlow' if use_meanflow else 'FlowMatching'}",