ford442 commited on
Commit
dca097b
·
verified ·
1 Parent(s): 881ee4d

Update demos/musicgen_app.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +13 -47
demos/musicgen_app.py CHANGED
@@ -8,11 +8,9 @@ import sys
8
  import time
9
  import typing as tp
10
  from tempfile import NamedTemporaryFile, gettempdir
11
-
12
  from einops import rearrange
13
  import torch
14
  import gradio as gr
15
-
16
  from audiocraft.data.audio_utils import convert_audio
17
  from audiocraft.data.audio import audio_write
18
  from audiocraft.models.encodec import InterleaveStereoCompressionModel
@@ -20,17 +18,26 @@ from audiocraft.models import MusicGen, MultiBandDiffusion
20
  import multiprocessing as mp
21
  import warnings
22
 
23
- # --- Utility Functions and Classes ---
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  class FileCleaner:
26
  def __init__(self, file_lifetime: float = 3600):
27
  self.file_lifetime = file_lifetime
28
  self.files = []
29
-
30
  def add(self, path: tp.Union[str, Path]):
31
  self._cleanup()
32
  self.files.append((time.time(), Path(path)))
33
-
34
  def _cleanup(self):
35
  now = time.time()
36
  for time_added, path in list(self.files):
@@ -40,9 +47,9 @@ class FileCleaner:
40
  self.files.pop(0)
41
  else:
42
  break
 
43
  file_cleaner = FileCleaner()
44
 
45
-
46
  def convert_wav_to_mp4(wav_path, output_path=None):
47
  """Converts a WAV file to a waveform MP4 video using ffmpeg."""
48
  if output_path is None:
@@ -62,19 +69,14 @@ def convert_wav_to_mp4(wav_path, output_path=None):
62
  "-preset", "fast", # Important, don't do veryslow.
63
  str(output_path),
64
  ]
65
-
66
  process = sp.run(command, capture_output=True, text=True, check=True)
67
  return str(output_path)
68
-
69
  except sp.CalledProcessError as e:
70
  print(f"Error in ffmpeg conversion: {e}")
71
  print(f"ffmpeg stdout: {e.stdout}")
72
  print(f"ffmpeg stderr: {e.stderr}")
73
  raise # Re-raise the exception to be caught by Gradio
74
 
75
-
76
- # --- Worker Process ---
77
-
78
  def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
79
  """
80
  Persistent worker process (used when NOT running as a daemon).
@@ -83,14 +85,11 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
83
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
84
  model = MusicGen.get_pretrained(model_name, device=device)
85
  mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)
86
-
87
  while True:
88
  task = task_queue.get()
89
  if task is None:
90
  break
91
-
92
  task_id, text, melody, duration, use_diffusion, gen_params = task
93
-
94
  try:
95
  model.set_generation_params(duration=duration, **gen_params)
96
  target_sr = model.sample_rate
@@ -103,7 +102,6 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
103
  melody_tensor = melody_tensor.unsqueeze(0)
104
  melody_tensor = melody_tensor[..., :int(sr * duration)]
105
  processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
106
-
107
  if processed_melody is not None:
108
  output, tokens = model.generate_with_chroma(
109
  descriptions=[text],
@@ -114,9 +112,7 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
114
  )
115
  else:
116
  output, tokens = model.generate([text], progress=True, return_tokens=True)
117
-
118
  output = output.detach().cpu()
119
-
120
  if use_diffusion:
121
  if isinstance(model.compression_model, InterleaveStereoCompressionModel):
122
  left, right = model.compression_model.get_left_right_codes(tokens)
@@ -129,16 +125,11 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
129
  result_queue.put((task_id, (output, outputs_diffusion)))
130
  else:
131
  result_queue.put((task_id, (output, None)))
132
-
133
  except Exception as e:
134
  result_queue.put((task_id, e))
135
-
136
  except Exception as e:
137
  result_queue.put((-1, e))
138
 
139
-
140
- # --- Predictor Class (Modified for conditional process creation) ---
141
-
142
  class Predictor:
143
  def __init__(self, model_name: str):
144
  self.model_name = model_name
@@ -190,7 +181,6 @@ class Predictor:
190
  melody_tensor = melody_tensor.unsqueeze(0)
191
  melody_tensor = melody_tensor[..., :int(sr * duration)]
192
  processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
193
-
194
  if processed_melody is not None:
195
  output, tokens = self.model.generate_with_chroma(
196
  descriptions=[text],
@@ -201,9 +191,7 @@ class Predictor:
201
  )
202
  else:
203
  output, tokens = self.model.generate([text], progress=True, return_tokens=True)
204
-
205
  output = output.detach().cpu()
206
-
207
  if use_diffusion:
208
  if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
209
  left, right = self.model.compression_model.get_left_right_codes(tokens)
@@ -216,11 +204,8 @@ class Predictor:
216
  return task_id, (output, outputs_diffusion) #Return the task id.
217
  else:
218
  return task_id, (output, None)
219
-
220
-
221
  except Exception as e:
222
  return task_id, e
223
-
224
  else:
225
  # Use the multiprocessing queue (multi-process mode)
226
  self.current_task_id += 1
@@ -239,7 +224,6 @@ class Predictor:
239
  result_task_id, result = self.result_queue.get()
240
  if result_task_id == task_id:
241
  break # Found the correct result
242
-
243
  if isinstance(result, Exception):
244
  raise result
245
  return result
@@ -250,14 +234,12 @@ class Predictor:
250
  self.task_queue.put(None)
251
  self.process.join()
252
 
253
-
254
  _default_model_name = "facebook/musicgen-melody"
255
 
256
  @spaces.GPU(duration=90) # Use the decorator for Spaces
257
  def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
258
  # Initialize Predictor *INSIDE* the function
259
  predictor = Predictor(model)
260
-
261
  task_id, (wav, diffusion_wav) = predictor.predict( # Unpack directly!
262
  text=text,
263
  melody=melody,
@@ -268,11 +250,9 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
268
  temperature=temperature,
269
  cfg_coef=cfg_coef,
270
  )
271
-
272
  # Save and return audio files
273
  wav_paths = []
274
  video_paths = []
275
-
276
  # Save standard output
277
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
278
  audio_write(
@@ -285,8 +265,6 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
285
  video_paths.append(video_path)
286
  file_cleaner.add(file.name)
287
  file_cleaner.add(video_path)
288
-
289
-
290
  # Save MBD output if used
291
  if diffusion_wav is not None:
292
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
@@ -300,31 +278,24 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
300
  video_paths.append(video_path)
301
  file_cleaner.add(file.name)
302
  file_cleaner.add(video_path)
303
-
304
  # Shutdown predictor to prevent hanging processes!
305
-
306
  if not predictor.is_daemon: # Important!
307
  predictor.shutdown()
308
-
309
  if use_mbd:
310
  return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
311
  return video_paths[0], wav_paths[0], None, None
312
 
313
-
314
  def toggle_audio_src(choice):
315
  if choice == "mic":
316
  return gr.update(sources="microphone", value=None, label="Microphone")
317
  else:
318
  return gr.update(sources="upload", value=None, label="File")
319
 
320
-
321
  def toggle_diffusion(choice):
322
  if choice == "MultiBand_Diffusion":
323
  return [gr.update(visible=True)] * 2
324
  else:
325
  return [gr.update(visible=False)] * 2
326
- # --- Gradio UI ---
327
-
328
 
329
  def ui_full(launch_kwargs):
330
  with gr.Blocks() as interface:
@@ -475,7 +446,6 @@ def ui_full(launch_kwargs):
475
 
476
  interface.queue().launch(**launch_kwargs)
477
 
478
- # --- Main Entry Point ---
479
  if __name__ == '__main__':
480
  parser = argparse.ArgumentParser()
481
  parser.add_argument(
@@ -502,12 +472,9 @@ if __name__ == '__main__':
502
  parser.add_argument(
503
  '--share', action='store_true', help='Share the gradio UI'
504
  )
505
-
506
  args = parser.parse_args()
507
-
508
  launch_kwargs = {}
509
  launch_kwargs['server_name'] = args.listen
510
-
511
  if args.username and args.password:
512
  launch_kwargs['auth'] = (args.username, args.password)
513
  if args.server_port:
@@ -516,7 +483,6 @@ if __name__ == '__main__':
516
  launch_kwargs['inbrowser'] = args.inbrowser
517
  if args.share:
518
  launch_kwargs['share'] = args.share
519
-
520
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
521
  # Added predictor shutdown
522
  try:
 
8
  import time
9
  import typing as tp
10
  from tempfile import NamedTemporaryFile, gettempdir
 
11
  from einops import rearrange
12
  import torch
13
  import gradio as gr
 
14
  from audiocraft.data.audio_utils import convert_audio
15
  from audiocraft.data.audio import audio_write
16
  from audiocraft.models.encodec import InterleaveStereoCompressionModel
 
18
  import multiprocessing as mp
19
  import warnings
20
 
21
+ os.putenv("HF_HUB_ENABLE_HF_TRANSFER","1")
22
+ os.environ["SAFETENSORS_FAST_GPU"] = "1"
23
+
24
+ torch.backends.cuda.matmul.allow_tf32 = False
25
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
26
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
27
+ torch.backends.cudnn.allow_tf32 = False
28
+ torch.backends.cudnn.deterministic = False
29
+ torch.backends.cudnn.benchmark = False
30
+ # torch.backends.cuda.preferred_blas_library="cublas"
31
+ # torch.backends.cuda.preferred_linalg_library="cusolver"
32
+ torch.set_float32_matmul_precision("highest")
33
 
34
  class FileCleaner:
35
  def __init__(self, file_lifetime: float = 3600):
36
  self.file_lifetime = file_lifetime
37
  self.files = []
 
38
  def add(self, path: tp.Union[str, Path]):
39
  self._cleanup()
40
  self.files.append((time.time(), Path(path)))
 
41
  def _cleanup(self):
42
  now = time.time()
43
  for time_added, path in list(self.files):
 
47
  self.files.pop(0)
48
  else:
49
  break
50
+
51
  file_cleaner = FileCleaner()
52
 
 
53
  def convert_wav_to_mp4(wav_path, output_path=None):
54
  """Converts a WAV file to a waveform MP4 video using ffmpeg."""
55
  if output_path is None:
 
69
  "-preset", "fast", # Important, don't do veryslow.
70
  str(output_path),
71
  ]
 
72
  process = sp.run(command, capture_output=True, text=True, check=True)
73
  return str(output_path)
 
74
  except sp.CalledProcessError as e:
75
  print(f"Error in ffmpeg conversion: {e}")
76
  print(f"ffmpeg stdout: {e.stdout}")
77
  print(f"ffmpeg stderr: {e.stderr}")
78
  raise # Re-raise the exception to be caught by Gradio
79
 
 
 
 
80
  def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
81
  """
82
  Persistent worker process (used when NOT running as a daemon).
 
85
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
86
  model = MusicGen.get_pretrained(model_name, device=device)
87
  mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)
 
88
  while True:
89
  task = task_queue.get()
90
  if task is None:
91
  break
 
92
  task_id, text, melody, duration, use_diffusion, gen_params = task
 
93
  try:
94
  model.set_generation_params(duration=duration, **gen_params)
95
  target_sr = model.sample_rate
 
102
  melody_tensor = melody_tensor.unsqueeze(0)
103
  melody_tensor = melody_tensor[..., :int(sr * duration)]
104
  processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
 
105
  if processed_melody is not None:
106
  output, tokens = model.generate_with_chroma(
107
  descriptions=[text],
 
112
  )
113
  else:
114
  output, tokens = model.generate([text], progress=True, return_tokens=True)
 
115
  output = output.detach().cpu()
 
116
  if use_diffusion:
117
  if isinstance(model.compression_model, InterleaveStereoCompressionModel):
118
  left, right = model.compression_model.get_left_right_codes(tokens)
 
125
  result_queue.put((task_id, (output, outputs_diffusion)))
126
  else:
127
  result_queue.put((task_id, (output, None)))
 
128
  except Exception as e:
129
  result_queue.put((task_id, e))
 
130
  except Exception as e:
131
  result_queue.put((-1, e))
132
 
 
 
 
133
  class Predictor:
134
  def __init__(self, model_name: str):
135
  self.model_name = model_name
 
181
  melody_tensor = melody_tensor.unsqueeze(0)
182
  melody_tensor = melody_tensor[..., :int(sr * duration)]
183
  processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
 
184
  if processed_melody is not None:
185
  output, tokens = self.model.generate_with_chroma(
186
  descriptions=[text],
 
191
  )
192
  else:
193
  output, tokens = self.model.generate([text], progress=True, return_tokens=True)
 
194
  output = output.detach().cpu()
 
195
  if use_diffusion:
196
  if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
197
  left, right = self.model.compression_model.get_left_right_codes(tokens)
 
204
  return task_id, (output, outputs_diffusion) #Return the task id.
205
  else:
206
  return task_id, (output, None)
 
 
207
  except Exception as e:
208
  return task_id, e
 
209
  else:
210
  # Use the multiprocessing queue (multi-process mode)
211
  self.current_task_id += 1
 
224
  result_task_id, result = self.result_queue.get()
225
  if result_task_id == task_id:
226
  break # Found the correct result
 
227
  if isinstance(result, Exception):
228
  raise result
229
  return result
 
234
  self.task_queue.put(None)
235
  self.process.join()
236
 
 
237
  _default_model_name = "facebook/musicgen-melody"
238
 
239
  @spaces.GPU(duration=90) # Use the decorator for Spaces
240
  def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
241
  # Initialize Predictor *INSIDE* the function
242
  predictor = Predictor(model)
 
243
  task_id, (wav, diffusion_wav) = predictor.predict( # Unpack directly!
244
  text=text,
245
  melody=melody,
 
250
  temperature=temperature,
251
  cfg_coef=cfg_coef,
252
  )
 
253
  # Save and return audio files
254
  wav_paths = []
255
  video_paths = []
 
256
  # Save standard output
257
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
258
  audio_write(
 
265
  video_paths.append(video_path)
266
  file_cleaner.add(file.name)
267
  file_cleaner.add(video_path)
 
 
268
  # Save MBD output if used
269
  if diffusion_wav is not None:
270
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
 
278
  video_paths.append(video_path)
279
  file_cleaner.add(file.name)
280
  file_cleaner.add(video_path)
 
281
  # Shutdown predictor to prevent hanging processes!
 
282
  if not predictor.is_daemon: # Important!
283
  predictor.shutdown()
 
284
  if use_mbd:
285
  return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
286
  return video_paths[0], wav_paths[0], None, None
287
 
 
288
  def toggle_audio_src(choice):
289
  if choice == "mic":
290
  return gr.update(sources="microphone", value=None, label="Microphone")
291
  else:
292
  return gr.update(sources="upload", value=None, label="File")
293
 
 
294
  def toggle_diffusion(choice):
295
  if choice == "MultiBand_Diffusion":
296
  return [gr.update(visible=True)] * 2
297
  else:
298
  return [gr.update(visible=False)] * 2
 
 
299
 
300
  def ui_full(launch_kwargs):
301
  with gr.Blocks() as interface:
 
446
 
447
  interface.queue().launch(**launch_kwargs)
448
 
 
449
  if __name__ == '__main__':
450
  parser = argparse.ArgumentParser()
451
  parser.add_argument(
 
472
  parser.add_argument(
473
  '--share', action='store_true', help='Share the gradio UI'
474
  )
 
475
  args = parser.parse_args()
 
476
  launch_kwargs = {}
477
  launch_kwargs['server_name'] = args.listen
 
478
  if args.username and args.password:
479
  launch_kwargs['auth'] = (args.username, args.password)
480
  if args.server_port:
 
483
  launch_kwargs['inbrowser'] = args.inbrowser
484
  if args.share:
485
  launch_kwargs['share'] = args.share
 
486
  logging.basicConfig(level=logging.INFO, stream=sys.stderr)
487
  # Added predictor shutdown
488
  try: