ford442 committed on
Commit 0b90fab · verified · 1 Parent(s): 16a9f66

gemini edit

Files changed (1)
  1. demos/musicgen_app.py +185 -337
demos/musicgen_app.py CHANGED
@@ -1,12 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
-# also released under the MIT license.
-import spaces
 import argparse
 from concurrent.futures import ProcessPoolExecutor
 import logging
@@ -14,7 +5,7 @@ import os
 from pathlib import Path
 import subprocess as sp
 import sys
 from tempfile import NamedTemporaryFile
 import time
 import typing as tp
 import warnings
@@ -27,39 +17,11 @@ from audiocraft.data.audio_utils import convert_audio
 from audiocraft.data.audio import audio_write
 from audiocraft.models.encodec import InterleaveStereoCompressionModel
 from audiocraft.models import MusicGen, MultiBandDiffusion


-MODEL = None  # Last used model
-SPACE_ID = os.environ.get('SPACE_ID', '')
-IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID
-print(IS_BATCHED)
-MAX_BATCH_SIZE = 12
-BATCHED_DURATION = 15
-INTERRUPTING = False
-MBD = None
-# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
-_old_call = sp.call
-
-
-def _call_nostderr(*args, **kwargs):
-    # Avoid ffmpeg vomiting on the logs.
-    kwargs['stderr'] = sp.DEVNULL
-    kwargs['stdout'] = sp.DEVNULL
-    _old_call(*args, **kwargs)
-
-
-sp.call = _call_nostderr
-# Preallocating the pool of processes.
-pool = ProcessPoolExecutor(4)
-pool.__enter__()
-
-
-def interrupt():
-    global INTERRUPTING
-    INTERRUPTING = True
-
-
-class FileCleaner:
     def __init__(self, file_lifetime: float = 3600):
         self.file_lifetime = file_lifetime
         self.files = []
@@ -77,13 +39,9 @@ class FileCleaner:
                 self.files.pop(0)
             else:
                 break
-
-
 file_cleaner = FileCleaner()

-
-def make_waveform(*args, **kwargs):
-    # Further remove some warnings.
     be = time.time()
     with warnings.catch_warnings():
         warnings.simplefilter('ignore')
@@ -91,139 +49,175 @@ def make_waveform(*args, **kwargs):
     print("Make a video took", time.time() - be)
     return out


-def load_model(version='facebook/musicgen-melody'):
-    global MODEL
-    print("Loading model", version)
-    if MODEL is None or MODEL.name != version:
-        del MODEL
-        MODEL = None  # in case loading would crash
-        MODEL = MusicGen.get_pretrained(version)


-def load_diffusion():
-    global MBD
-    if MBD is None:
-        print("loading MBD")
-        MBD = MultiBandDiffusion.get_mbd_musicgen()


-@spaces.GPU(duration=65)
-def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs):
-    MODEL.set_generation_params(duration=duration, **gen_kwargs)
-    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
-    be = time.time()
-    processed_melodies = []
-    target_sr = 32000
-    target_ac = 1
-    for melody in melodies:
-        if melody is None:
-            processed_melodies.append(None)
-        else:
-            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
-            if melody.dim() == 1:
-                melody = melody[None]
-            melody = melody[..., :int(sr * duration)]
-            melody = convert_audio(melody, sr, target_sr, target_ac)
-            processed_melodies.append(melody)
-
-    try:
-        if any(m is not None for m in processed_melodies):
-            outputs = MODEL.generate_with_chroma(
-                descriptions=texts,
-                melody_wavs=processed_melodies,
-                melody_sample_rate=target_sr,
-                progress=progress,
-                return_tokens=USE_DIFFUSION
-            )
-        else:
-            outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
-    except RuntimeError as e:
-        raise gr.Error("Error while generating " + e.args[0])
-    if USE_DIFFUSION:
-        if gradio_progress is not None:
-            gradio_progress(1, desc='Running MultiBandDiffusion...')
-        tokens = outputs[1]
-        if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
-            left, right = MODEL.compression_model.get_left_right_codes(tokens)
-            tokens = torch.cat([left, right])
-        outputs_diffusion = MBD.tokens_to_wav(tokens)
-        if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
-            assert outputs_diffusion.shape[1] == 1  # output is mono
-            outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
-        outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
-    outputs = outputs.detach().cpu().float()
-    pending_videos = []
-    out_wavs = []
-    for output in outputs:
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
             audio_write(
-                file.name, output, MODEL.sample_rate, strategy="loudness",
-                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-            pending_videos.append(pool.submit(make_waveform, file.name))
-            out_wavs.append(file.name)
             file_cleaner.add(file.name)
-    out_videos = [pending_video.result() for pending_video in pending_videos]
-    for video in out_videos:
-        file_cleaner.add(video)
-    print("batch finished", len(texts), time.time() - be)
-    print("Tempfiles currently stored: ", len(file_cleaner.files))
-    return out_videos, out_wavs
-
-
-def predict_batched(texts, melodies):
-    max_text_length = 512
-    texts = [text[:max_text_length] for text in texts]
-    load_model('facebook/musicgen-stereo-melody')
-    res = _do_predictions(texts, melodies, BATCHED_DURATION)
-    return res
-
-
-def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
-    global INTERRUPTING
-    global USE_DIFFUSION
-    INTERRUPTING = False
-    progress(0, desc="Loading model...")
-    model_path = model_path.strip()
-    if model_path:
-        if not Path(model_path).exists():
-            raise gr.Error(f"Model path {model_path} doesn't exist.")
-        if not Path(model_path).is_dir():
-            raise gr.Error(f"Model path {model_path} must be a folder containing "
-                           "state_dict.bin and compression_state_dict_.bin.")
-        model = model_path
-    if temperature < 0:
-        raise gr.Error("Temperature must be >= 0.")
-    if topk < 0:
-        raise gr.Error("Topk must be non-negative.")
-    if topp < 0:
-        raise gr.Error("Topp must be non-negative.")
-
-    topk = int(topk)
-    if decoder == "MultiBand_Diffusion":
-        USE_DIFFUSION = True
-        progress(0, desc="Loading diffusion model...")
-        load_diffusion()
-    else:
-        USE_DIFFUSION = False
-    load_model(model)
-
-    max_generated = 0
-
-    def _progress(generated, to_generate):
-        nonlocal max_generated
-        max_generated = max(generated, max_generated)
-        progress((min(max_generated, to_generate), to_generate))
-        if INTERRUPTING:
-            raise gr.Error("Interrupted.")
-    MODEL.set_custom_progress_callback(_progress)

-    videos, wavs = _do_predictions(
-        [text], [melody], duration, progress=True,
-        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef,
-        gradio_progress=progress)
-    if USE_DIFFUSION:
-        return videos[0], wavs[0], videos[1], wavs[1]
-    return videos[0], wavs[0], None, None


 def toggle_audio_src(choice):
@@ -238,7 +232,7 @@ def toggle_diffusion(choice):
         return [gr.update(visible=True)] * 2
     else:
         return [gr.update(visible=False)] * 2
-

 def ui_full(launch_kwargs):
     with gr.Blocks() as interface:
@@ -261,16 +255,15 @@ def ui_full(launch_kwargs):
                                       interactive=True, elem_id="melody-input")
                 with gr.Row():
                     submit = gr.Button("Submit")
-                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
-                    _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
                 with gr.Row():
                     model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
                                       "facebook/musicgen-large", "facebook/musicgen-melody-large",
                                       "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
                                       "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
                                       "facebook/musicgen-stereo-melody-large"],
-                                     label="Model", value="facebook/musicgen-stereo-melody", interactive=True)
-                    model_path = gr.Text(label="Model Path (custom models)")
                 with gr.Row():
                     decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
                                        label="Decoder", value="Default", interactive=True)
@@ -284,12 +277,16 @@ def ui_full(launch_kwargs):
             with gr.Column():
                 output = gr.Video(label="Generated Music")
                 audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
-                diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
-                audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
-        submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
-                     show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp,
-                                                                     temperature, cfg_coef],
-                                               outputs=[output, audio_output, diffusion_output, audio_diffusion])
         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)

         gr.Examples(
@@ -298,37 +295,37 @@ def ui_full(launch_kwargs):
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
-                   "facebook/musicgen-stereo-melody",
                    "Default"
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
-                   "facebook/musicgen-stereo-melody",
                    "Default"
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
-                   "facebook/musicgen-stereo-medium",
                    "Default"
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
                    "./assets/bach.mp3",
-                   "facebook/musicgen-stereo-melody",
                    "Default"
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
-                   "facebook/musicgen-stereo-medium",
                    "Default"
                ],
                [
                    "Punk rock with loud drum and power guitar",
                    None,
-                   "facebook/musicgen-stereo-medium",
                    "MultiBand_Diffusion"
                ],
            ],
@@ -373,153 +370,4 @@ def ui_full(launch_kwargs):
             for crashes, snares etc.
             2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). Should improve the audio quality,
             at an extra computational cost. When this is selected, we provide both the GAN based decoded
-            audio, and the one obtained with MBD.
-
-            See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
-            for more details.
-            """
-        )
-
-    interface.queue().launch(**launch_kwargs)
-
-
-def ui_batched(launch_kwargs):
-    with gr.Blocks() as demo:
-        gr.Markdown(
-            """
-            # MusicGen
-
-            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md),
-            a simple and controllable model for music generation
-            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
-            <br/>
-            <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
-               style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
-                 src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-            for longer sequences, more control and no queue.</p>
-            """
-        )
-        with gr.Row():
-            with gr.Column():
-                with gr.Row():
-                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
-                    with gr.Column():
-                        radio = gr.Radio(["file", "mic"], value="file",
-                                         label="Condition on a melody (optional) File or Mic")
-                        melody = gr.Audio(sources="upload", type="numpy", label="File",
-                                          interactive=True, elem_id="melody-input")
-                with gr.Row():
-                    submit = gr.Button("Generate")
-            with gr.Column():
-                output = gr.Video(label="Generated Music")
-                audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
-        submit.click(predict_batched, inputs=[text, melody],
-                     outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
-        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
-        gr.Examples(
-            fn=predict_batched,
-            examples=[
-                [
-                    "An 80s driving pop song with heavy drums and synth pads in the background",
-                    "./assets/bach.mp3",
-                ],
-                [
-                    "A cheerful country song with acoustic guitars",
-                    "./assets/bolero_ravel.mp3",
-                ],
-                [
-                    "90s rock song with electric guitar and heavy drums",
-                    None,
-                ],
-                [
-                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
-                    "./assets/bach.mp3",
-                ],
-                [
-                    "lofi slow bpm electro chill with organic samples",
-                    None,
-                ],
-            ],
-            inputs=[text, melody],
-            outputs=[output]
-        )
-        gr.Markdown("""
-        ### More details
-
-        The model will generate 15 seconds of audio based on the description you provided.
-        The model was trained with description from a stock music catalog, descriptions that will work best
-        should include some level of details on the instruments present, along with some intended use case
-        (e.g. adding "perfect for a commercial" can somehow help).
-
-        You can optionally provide a reference audio from which a broad melody will be extracted.
-        The model will then try to follow both the description and melody provided.
-        For best results, the melody should be 30 seconds long (I know, the samples we provide are not...)
-
-        You can access more control (longer generation, more models etc.) by clicking
-        the <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
-        style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-        <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
-        src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-        (you will then need a paid GPU from HuggingFace).
-        If you have a GPU, you can run the gradio demo locally (click the link to our repo below for more info).
-        Finally, you can get a GPU for free from Google
-        and run the demo in [a Google Colab.](https://ai.honu.io/red/musicgen-colab).
-
-        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
-        for more details. All samples are generated with the `stereo-melody` model.
-        """)
-
-    demo.queue(max_size=8 * 4).launch(**launch_kwargs)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--listen',
-        type=str,
-        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
-        help='IP to listen on for connections to Gradio',
-    )
-    parser.add_argument(
-        '--username', type=str, default='', help='Username for authentication'
-    )
-    parser.add_argument(
-        '--password', type=str, default='', help='Password for authentication'
-    )
-    parser.add_argument(
-        '--server_port',
-        type=int,
-        default=0,
-        help='Port to run the server listener on',
-    )
-    parser.add_argument(
-        '--inbrowser', action='store_true', help='Open in browser'
-    )
-    parser.add_argument(
-        '--share', action='store_true', help='Share the gradio UI'
-    )
-
-    args = parser.parse_args()
-
-    launch_kwargs = {}
-    launch_kwargs['server_name'] = args.listen
-
-    if args.username and args.password:
-        launch_kwargs['auth'] = (args.username, args.password)
-    if args.server_port:
-        launch_kwargs['server_port'] = args.server_port
-    if args.inbrowser:
-        launch_kwargs['inbrowser'] = args.inbrowser
-    if args.share:
-        launch_kwargs['share'] = args.share
-
-    logging.basicConfig(level=logging.INFO, stream=sys.stderr)
-
-    # Show the interface
-    if IS_BATCHED:
-        global USE_DIFFUSION
-        USE_DIFFUSION = False
-        ui_batched(launch_kwargs)
-    else:
-        ui_full(launch_kwargs)
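With `IS_BATCHED`, `ui_batched` and the `__main__` block removed above, only the full UI remains. For orientation, a minimal launch sketch (hypothetical code, not part of this commit; it assumes an entry point is reintroduced and that `ui_full` still ends in `interface.queue().launch(**launch_kwargs)`, taking the same `launch_kwargs` keys the removed argparse block produced):

```python
# Hypothetical launcher; the keys mirror the removed argparse flags.
if __name__ == "__main__":
    launch_kwargs = {
        'server_name': '127.0.0.1',  # --listen
        'server_port': 7860,         # --server_port
        'inbrowser': True,           # --inbrowser
    }
    ui_full(launch_kwargs)
```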
 import argparse
 from concurrent.futures import ProcessPoolExecutor
 import logging
 import os
 from pathlib import Path
 import subprocess as sp
 import sys
 from tempfile import NamedTemporaryFile
 import time
 import typing as tp
 import warnings

 from audiocraft.data.audio import audio_write
 from audiocraft.models.encodec import InterleaveStereoCompressionModel
 from audiocraft.models import MusicGen, MultiBandDiffusion
+import multiprocessing as mp
+import queue  # for queue.Empty in Predictor._check_initialization
+
+# --- Utility Functions and Classes ---
+
+class FileCleaner:  # Unchanged from the original demo, included for completeness
     def __init__(self, file_lifetime: float = 3600):
         self.file_lifetime = file_lifetime
         self.files = []

                 self.files.pop(0)
             else:
                 break
 file_cleaner = FileCleaner()

+def make_waveform(*args, **kwargs):  # Unchanged
     be = time.time()
     with warnings.catch_warnings():
         warnings.simplefilter('ignore')

     print("Make a video took", time.time() - be)
     return out

+# --- Worker Process ---
+
+def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
+    """
+    Persistent worker process that loads the model and handles prediction tasks.
+    """
+    try:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        model = MusicGen.get_pretrained(model_name, device=device)
+        mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)  # Load MBD here too
+
+        while True:
+            task = task_queue.get()
+            if task is None:  # Sentinel value to exit
+                break
+
+            task_id, text, melody, duration, use_diffusion, gen_params = task
+
+            try:
+                model.set_generation_params(duration=duration, **gen_params)
+                target_sr = model.sample_rate
+                target_ac = 1
+                processed_melody = None
+                if melody:
+                    sr, melody_data = melody
+                    melody_tensor = torch.from_numpy(melody_data).to(device).float().t()
+                    if melody_tensor.ndim == 1:
+                        melody_tensor = melody_tensor.unsqueeze(0)
+                    melody_tensor = melody_tensor[..., :int(sr * duration)]
+                    processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
+
+                if processed_melody is not None:
+                    output, tokens = model.generate_with_chroma(
+                        descriptions=[text],
+                        melody_wavs=[processed_melody],
+                        melody_sample_rate=target_sr,
+                        progress=True,
+                        return_tokens=True
+                    )
+                else:
+                    output, tokens = model.generate([text], progress=True, return_tokens=True)
+
+                output = output.detach().cpu()
+
+                if use_diffusion:
+                    if isinstance(model.compression_model, InterleaveStereoCompressionModel):
+                        left, right = model.compression_model.get_left_right_codes(tokens)
+                        tokens = torch.cat([left, right])
+                    outputs_diffusion = mbd.tokens_to_wav(tokens)
+                    if isinstance(model.compression_model, InterleaveStereoCompressionModel):
+                        assert outputs_diffusion.shape[1] == 1  # output is mono
+                        outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
+                    outputs_diffusion = outputs_diffusion.detach().cpu()
+                    result_queue.put((task_id, (output, outputs_diffusion)))  # Send BOTH results
+                else:
+                    result_queue.put((task_id, (output, None)))  # Send back the result
+
+            except Exception as e:
+                result_queue.put((task_id, e))  # Send back the exception
+
+    except Exception as e:
+        result_queue.put((-1, e))  # Fatal error on loading
+
+# --- Gradio Interface Functions ---
+
+class Predictor:
+    def __init__(self, model_name: str):
+        self.task_queue = mp.Queue()
+        self.result_queue = mp.Queue()
+        self.process = mp.Process(target=model_worker, args=(model_name, self.task_queue, self.result_queue))
+        self.process.start()
+        self.current_task_id = 0
+        self._check_initialization()
+
+    def _check_initialization(self):
+        """Check if the worker process initialized successfully."""
+        # Give it some time to either load or report failure.
+        time.sleep(2)
+        try:
+            task_id, result = self.result_queue.get(timeout=3)  # Get result from model_worker
+            if isinstance(result, Exception):
+                if task_id == -1:
+                    raise RuntimeError("Model loading failed in worker process.") from result
+        except queue.Empty:
+            pass  # Nothing reported within the timeout: assume loading is still in progress
+
+    def predict(self, text, melody, duration, use_diffusion, **gen_params):
+        """
+        Submits a prediction task to the worker process.
+        """
+        self.current_task_id += 1
+        task = (self.current_task_id, text, melody, duration, use_diffusion, gen_params)
+        self.task_queue.put(task)
+        return self.current_task_id
+
+    def get_result(self, task_id):
+        """
+        Retrieves the result of a prediction task. Blocks until the result is available.
+        """
+        while True:  # Loop until the matching task id comes back
+            result_task_id, result = self.result_queue.get()
+            if result_task_id == task_id:
+                if isinstance(result, Exception):
+                    raise result  # Re-raise the exception in the main process
+                return result  # (wav, diffusion_wav) or (wav, None)
+
+    def shutdown(self):
+        """
+        Shuts down the worker process.
+        """
+        self.task_queue.put(None)  # Send sentinel value to stop the worker
+        self.process.join()  # Wait for the process to terminate
+
+
+# Global predictor instance
+_predictor = None
+
+def get_predictor(model_name: str = 'facebook/musicgen-melody'):
+    global _predictor
+    if _predictor is None:
+        _predictor = Predictor(model_name)
+    return _predictor
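For reference, the full round trip through `Predictor` looks like this minimal sketch (hypothetical driver code, not part of the commit). Note that `get_predictor` binds the worker to whichever model name the first call passes; later calls reuse that worker regardless of the `model_name` argument:

```python
# Hypothetical usage of the Predictor round trip.
predictor = get_predictor('facebook/musicgen-small')  # spawns the worker once
task_id = predictor.predict(
    text="lofi slow bpm electro chill with organic samples",
    melody=None,               # or (sample_rate, numpy_array) from gr.Audio
    duration=10,
    use_diffusion=False,
    top_k=250, top_p=0.0, temperature=1.0, cfg_coef=3.0,
)
wav, diffusion_wav = predictor.get_result(task_id)  # blocks; diffusion_wav is None here
predictor.shutdown()                                # sends the None sentinel and joins
```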
+
+def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef):
+    # The UI passes the decoder radio choice as a string; map it onto a boolean for the worker.
+    use_mbd = decoder == "MultiBand_Diffusion"
+
+    predictor = get_predictor(model)
+    task_id = predictor.predict(
+        text=text,
+        melody=melody,
+        duration=duration,
+        use_diffusion=use_mbd,
+        top_k=topk,
+        top_p=topp,
+        temperature=temperature,
+        cfg_coef=cfg_coef,
+    )

+    wav, diffusion_wav = predictor.get_result(task_id)

+    # Save and return audio files
+    wav_paths = []
+    video_paths = []
+
+    # Save standard output
+    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+        audio_write(
+            file.name, wav[0], 32000, strategy="loudness",  # hardcoded sample rate
+            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False
+        )
+        wav_paths.append(file.name)
+        video_paths.append(make_waveform(file.name))  # Render the waveform video
+        file_cleaner.add(file.name)
+        file_cleaner.add(video_paths[-1])  # register the video for cleanup as well
+
+    # Save MBD output if used
+    if diffusion_wav is not None:
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
             audio_write(
+                file.name, diffusion_wav[0], 32000, strategy="loudness",  # hardcoded sample rate
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False
+            )
+            wav_paths.append(file.name)
+            video_paths.append(make_waveform(file.name))  # Render the waveform video
             file_cleaner.add(file.name)
+            file_cleaner.add(video_paths[-1])  # register the video for cleanup as well

+    if use_mbd:
+        return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
+    return video_paths[0], wav_paths[0], None, None
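The four return values line up positionally with `outputs=[output, audio_output, diffusion_output, audio_diffusion]` in the `submit.click(...).then(...)` wiring below; when MultiBand Diffusion is off, the last two slots are `None` and those components simply stay empty. `model_path` is still accepted to keep the signature aligned with the UI inputs, but the body no longer uses it.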
 
 
 
 

 def toggle_audio_src(choice):

         return [gr.update(visible=True)] * 2
     else:
         return [gr.update(visible=False)] * 2
+# --- Gradio UI ---

 def ui_full(launch_kwargs):
     with gr.Blocks() as interface:
 
                                       interactive=True, elem_id="melody-input")
                 with gr.Row():
                     submit = gr.Button("Submit")
+                    # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)  # removed: the old global-flag interrupt does not reach the worker process
                 with gr.Row():
                     model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
                                       "facebook/musicgen-large", "facebook/musicgen-melody-large",
                                       "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
                                       "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
                                       "facebook/musicgen-stereo-melody-large"],
+                                     label="Model", value="facebook/musicgen-melody", interactive=True)
+                    model_path = gr.Text(label="Model Path (custom models)", interactive=False, visible=False)  # Keep, but hide
                 with gr.Row():
                     decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
                                        label="Decoder", value="Default", interactive=True)
 
             with gr.Column():
                 output = gr.Video(label="Generated Music")
                 audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
+                diffusion_output = gr.Video(label="MultiBand Diffusion Decoder", visible=False)
+                audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath', visible=False)
+
+        submit.click(
+            toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False
+        ).then(
+            predict_full,
+            inputs=[model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef],
+            outputs=[output, audio_output, diffusion_output, audio_diffusion]
+        )
         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)

         gr.Examples(
 
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
+                   "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
+                   "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
+                   "facebook/musicgen-medium",
                    "Default"
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
                    "./assets/bach.mp3",
+                   "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
+                   "facebook/musicgen-medium",
                    "Default"
                ],
                [
                    "Punk rock with loud drum and power guitar",
                    None,
+                   "facebook/musicgen-medium",
                    "MultiBand_Diffusion"
                ],
            ],
 
             for crashes, snares etc.
             2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). Should improve the audio quality,
             at an extra computational cost. When this is selected, we provide both the GAN based decoded
+            audio, and the one obtained with MBD.