gemini edit
demos/musicgen_app.py  (+185 -337)  CHANGED
@@ -1,12 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
-# also released under the MIT license.
-import spaces
 import argparse
 from concurrent.futures import ProcessPoolExecutor
 import logging
@@ -14,7 +5,6 @@ import os
 from pathlib import Path
 import subprocess as sp
 import sys
-from tempfile import NamedTemporaryFile
 import time
 import typing as tp
 import warnings
@@ -27,39 +17,11 @@ from audiocraft.data.audio_utils import convert_audio
 from audiocraft.data.audio import audio_write
 from audiocraft.models.encodec import InterleaveStereoCompressionModel
 from audiocraft.models import MusicGen, MultiBandDiffusion
+import multiprocessing as mp
 
+# --- Utility Functions and Classes ---
 
-
-SPACE_ID = os.environ.get('SPACE_ID', '')
-IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID
-print(IS_BATCHED)
-MAX_BATCH_SIZE = 12
-BATCHED_DURATION = 15
-INTERRUPTING = False
-MBD = None
-# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
-_old_call = sp.call
-
-
-def _call_nostderr(*args, **kwargs):
-    # Avoid ffmpeg vomiting on the logs.
-    kwargs['stderr'] = sp.DEVNULL
-    kwargs['stdout'] = sp.DEVNULL
-    _old_call(*args, **kwargs)
-
-
-sp.call = _call_nostderr
-# Preallocating the pool of processes.
-pool = ProcessPoolExecutor(4)
-pool.__enter__()
-
-
-def interrupt():
-    global INTERRUPTING
-    INTERRUPTING = True
-
-
-class FileCleaner:
+class FileCleaner:  # Unchanged from previous example, included for completeness
     def __init__(self, file_lifetime: float = 3600):
         self.file_lifetime = file_lifetime
         self.files = []
@@ -77,13 +39,9 @@ class FileCleaner:
             self.files.pop(0)
         else:
             break
-
-
 file_cleaner = FileCleaner()
 
-
-def make_waveform(*args, **kwargs):
-    # Further remove some warnings.
+def make_waveform(*args, **kwargs):  # Unchanged
     be = time.time()
     with warnings.catch_warnings():
         warnings.simplefilter('ignore')
@@ -91,139 +49,175 @@ def make_waveform(*args, **kwargs):
     print("Make a video took", time.time() - be)
     return out
 
-
-def load_model(version='facebook/musicgen-melody'):
-    global MODEL
-    print("Loading model", version)
-    if MODEL is None or MODEL.name != version:
-        # Clear PyTorch CUDA cache and delete model
-        del MODEL
-        torch.cuda.empty_cache()
-        MODEL = None  # in case loading would crash
-        MODEL = MusicGen.get_pretrained(version)
-
-
-def load_diffusion():
-    global MBD
-    if MBD is None:
-        print("loading MBD")
-        MBD = MultiBandDiffusion.get_mbd_musicgen()
-
-
-def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs):
-    MODEL.set_generation_params(duration=duration, **gen_kwargs)
-    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
-    be = time.time()
-    processed_melodies = []
-    target_sr = 32000
-    target_ac = 1
-    for melody in melodies:
-        if melody is None:
-            processed_melodies.append(None)
-        else:
-            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
-            if melody.dim() == 1:
-                melody = melody[None]
-            melody = melody[..., :int(sr * duration)]
-            melody = convert_audio(melody, sr, target_sr, target_ac)
-            processed_melodies.append(melody)
-
-    try:
-        if any(m is not None for m in processed_melodies):
-            outputs = MODEL.generate_with_chroma(
-                descriptions=texts,
-                melody_wavs=processed_melodies,
-                melody_sample_rate=target_sr,
-                progress=progress,
-                return_tokens=USE_DIFFUSION
-            )
-        else:
-            outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
-    except RuntimeError as e:
-        raise gr.Error("Error while generating " + e.args[0])
-    if USE_DIFFUSION:
-        tokens = outputs[1]
-        if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
-            left, right = MODEL.compression_model.get_left_right_codes(tokens)
-            tokens = torch.cat([left, right])
-        outputs_diffusion = MBD.tokens_to_wav(tokens)
-        if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
-            assert outputs_diffusion.shape[1] == 1  # output is mono
-            outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
-        outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
-    outputs = outputs.detach().cpu().float()
-    pending_videos = []
-    out_wavs = []
-    for output in outputs:
+# --- Worker Process ---
+
+def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
+    """
+    Persistent worker process that loads the model and handles prediction tasks.
+    """
+    try:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        model = MusicGen.get_pretrained(model_name, device=device)
+        mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)  # Load MBD here too
+
+        while True:
+            task = task_queue.get()
+            if task is None:  # Sentinel value to exit
+                break
+
+            task_id, text, melody, duration, use_diffusion, gen_params = task
+
+            try:
+                model.set_generation_params(duration=duration, **gen_params)
+                target_sr = model.sample_rate
+                target_ac = 1
+                processed_melody = None
+                if melody:
+                    sr, melody_data = melody
+                    melody_tensor = torch.from_numpy(melody_data).to(device).float().t()
+                    if melody_tensor.ndim == 1:
+                        melody_tensor = melody_tensor.unsqueeze(0)
+                    melody_tensor = melody_tensor[..., :int(sr * duration)]
+                    processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
+
+                if processed_melody is not None:
+                    output, tokens = model.generate_with_chroma(
+                        descriptions=[text],
+                        melody_wavs=[processed_melody],
+                        melody_sample_rate=target_sr,
+                        progress=True,
+                        return_tokens=True
+                    )
+                else:
+                    output, tokens = model.generate([text], progress=True, return_tokens=True)
+
+                output = output.detach().cpu()
+
+                if use_diffusion:
+                    if isinstance(model.compression_model, InterleaveStereoCompressionModel):
+                        left, right = model.compression_model.get_left_right_codes(tokens)
+                        tokens = torch.cat([left, right])
+                    outputs_diffusion = mbd.tokens_to_wav(tokens)
+                    if isinstance(model.compression_model, InterleaveStereoCompressionModel):
+                        assert outputs_diffusion.shape[1] == 1  # output is mono
+                        outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
+                    outputs_diffusion = outputs_diffusion.detach().cpu()
+                    result_queue.put((task_id, (output, outputs_diffusion)))  # Send BOTH results.
+                else:
+                    result_queue.put((task_id, (output, None)))  # Send back the result
+
+            except Exception as e:
+                result_queue.put((task_id, e))  # Send back the exception
+
+    except Exception as e:
+        result_queue.put((-1, e))  # Fatal error on loading.
+
+# --- Gradio Interface Functions ---
+
+class Predictor:
+    def __init__(self, model_name: str):
+        self.task_queue = mp.Queue()
+        self.result_queue = mp.Queue()
+        self.process = mp.Process(target=model_worker, args=(model_name, self.task_queue, self.result_queue))
+        self.process.start()
+        self.current_task_id = 0
+        self._check_initialization()
+
+    def _check_initialization(self):
+        """Check if the worker process initialized successfully."""
+        # Give it some time to either load or report failure.
+        time.sleep(2)
+        try:
+            task_id, result = self.result_queue.get(timeout=3)  # Get result from model_worker
+            if isinstance(result, Exception):
+                if task_id == -1:
+                    raise RuntimeError("Model loading failed in worker process.") from result
+        except:
+            pass  # Expected if model loads fast enough
+
+    def predict(self, text, melody, duration, use_diffusion, **gen_params):
+        """
+        Submits a prediction task to the worker process.
+        """
+        self.current_task_id += 1
+        task = (self.current_task_id, text, melody, duration, use_diffusion, gen_params)
+        self.task_queue.put(task)
+        return self.current_task_id
+
+    def get_result(self, task_id):
+        """
+        Retrieves the result of a prediction task. Blocks until the result is available.
+        """
+        while True:  # Loop to get the correct task
+            result_task_id, result = self.result_queue.get()
+            if result_task_id == task_id:
+                if isinstance(result, Exception):
+                    raise result  # Re-raise the exception in the main process
+                return result  # (wav, diffusion_wav) or (wav, None)
+
+    def shutdown(self):
+        """
+        Shuts down the worker process.
+        """
+        self.task_queue.put(None)  # Send sentinel value to stop the worker
+        self.process.join()  # Wait for the process to terminate
+
+
+# Global predictor instance
+_predictor = None
+
+def get_predictor(model_name: str = 'facebook/musicgen-melody'):
+    global _predictor
+    if _predictor is None:
+        _predictor = Predictor(model_name)
+    return _predictor
+
+def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
+
+    predictor = get_predictor(model)
+    task_id = predictor.predict(
+        text=text,
+        melody=melody,
+        duration=duration,
+        use_diffusion=use_mbd,
+        top_k=topk,
+        top_p=topp,
+        temperature=temperature,
+        cfg_coef=cfg_coef,
+    )
+    wav, diffusion_wav = predictor.get_result(task_id)
+
+    # Save and return audio files
+    wav_paths = []
+    video_paths = []
+
+    # Save standard output
+    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+        audio_write(
+            file.name, wav[0], 32000, strategy="loudness",  # hardcoded sample rate
+            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False
+        )
+        wav_paths.append(file.name)
+        video_paths.append(make_waveform(file.name))  # Make and clean up video
+        file_cleaner.add(file.name)
+
+    # Save MBD output if used
+    if diffusion_wav is not None:
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
             audio_write(
-                file.name, output, MODEL.sample_rate, strategy="loudness",
-                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-            pending_videos.append(pool.submit(make_waveform, file.name))
-            out_wavs.append(file.name)
+                file.name, diffusion_wav[0], 32000, strategy="loudness",  # hardcoded sample rate
+                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False
+            )
+            wav_paths.append(file.name)
+            video_paths.append(make_waveform(file.name))  # Make and clean up video
             file_cleaner.add(file.name)
-    out_videos = [pending_video.result() for pending_video in pending_videos]
-    for video in out_videos:
-        file_cleaner.add(video)
-    print("batch finished", len(texts), time.time() - be)
-    print("Tempfiles currently stored: ", len(file_cleaner.files))
-    return out_videos, out_wavs
-
-
-def predict_batched(texts, melodies):
-    max_text_length = 512
-    texts = [text[:max_text_length] for text in texts]
-    load_model('facebook/musicgen-stereo-melody')
-    res = _do_predictions(texts, melodies, BATCHED_DURATION)
-    return res
-
-
-def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
-    global INTERRUPTING
-    global USE_DIFFUSION
-    INTERRUPTING = False
-    progress(0, desc="Loading model...")
-    model_path = model_path.strip()
-    if model_path:
-        if not Path(model_path).exists():
-            raise gr.Error(f"Model path {model_path} doesn't exist.")
-        if not Path(model_path).is_dir():
-            raise gr.Error(f"Model path {model_path} must be a folder containing "
-                           "state_dict.bin and compression_state_dict_.bin.")
-        model = model_path
-    if temperature < 0:
-        raise gr.Error("Temperature must be >= 0.")
-    if topk < 0:
-        raise gr.Error("Topk must be non-negative.")
-    if topp < 0:
-        raise gr.Error("Topp must be non-negative.")
-
-    topk = int(topk)
-    if decoder == "MultiBand_Diffusion":
-        USE_DIFFUSION = True
-        progress(0, desc="Loading diffusion model...")
-        load_diffusion()
-    else:
-        USE_DIFFUSION = False
-    load_model(model)
-
-    max_generated = 0
-
-    def _progress(generated, to_generate):
-        nonlocal max_generated
-        max_generated = max(generated, max_generated)
-        progress((min(max_generated, to_generate), to_generate))
-        if INTERRUPTING:
-            raise gr.Error("Interrupted.")
-    MODEL.set_custom_progress_callback(_progress)
-
-    videos, wavs = _do_predictions(
-        [text], [melody], duration, progress=True,
-        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef,
-        gradio_progress=progress)
-    if USE_DIFFUSION:
-        return videos[0], wavs[0], videos[1], wavs[1]
-    return videos[0], wavs[0], None, None
+
+    if use_mbd:
+        return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
+    return video_paths[0], wav_paths[0], None, None
 
 
 def toggle_audio_src(choice):
@@ -238,7 +232,7 @@ def toggle_diffusion(choice):
         return [gr.update(visible=True)] * 2
     else:
         return [gr.update(visible=False)] * 2
-
+# --- Gradio UI ---
 
 def ui_full(launch_kwargs):
     with gr.Blocks() as interface:
@@ -261,16 +255,15 @@ def ui_full(launch_kwargs):
                                      interactive=True, elem_id="melody-input")
                 with gr.Row():
                     submit = gr.Button("Submit")
-                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
-                    _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
+                    # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)  # Interrupt is now handled implicitly
                 with gr.Row():
                     model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
                                       "facebook/musicgen-large", "facebook/musicgen-melody-large",
                                       "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
                                       "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
                                       "facebook/musicgen-stereo-melody-large"],
-                                     label="Model", value="facebook/musicgen-stereo-melody", interactive=True)
-                    model_path = gr.Text(label="Model Path (custom models)")
+                                     label="Model", value="facebook/musicgen-melody", interactive=True)
+                    model_path = gr.Text(label="Model Path (custom models)", interactive=False, visible=False)  # Keep, but hide
                 with gr.Row():
                     decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
                                        label="Decoder", value="Default", interactive=True)
@@ -284,12 +277,16 @@ def ui_full(launch_kwargs):
             with gr.Column():
                 output = gr.Video(label="Generated Music")
                 audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
-                diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
-                audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
-        submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
-                     show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody,
-                                                                     duration, topk, topp, temperature, cfg_coef],
-                                               outputs=[output, audio_output, diffusion_output, audio_diffusion])
+                diffusion_output = gr.Video(label="MultiBand Diffusion Decoder", visible=False)
+                audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath', visible=False)
+
+        submit.click(
+            toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False
+        ).then(
+            predict_full,
+            inputs=[model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef],
+            outputs=[output, audio_output, diffusion_output, audio_diffusion]
+        )
         radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
 
         gr.Examples(
@@ -298,37 +295,37 @@ def ui_full(launch_kwargs):
                 [
                     "An 80s driving pop song with heavy drums and synth pads in the background",
                     "./assets/bach.mp3",
-                    "facebook/musicgen-stereo-melody",
+                    "facebook/musicgen-melody",
                     "Default"
                 ],
                 [
                     "A cheerful country song with acoustic guitars",
                     "./assets/bolero_ravel.mp3",
-                    "facebook/musicgen-stereo-melody",
+                    "facebook/musicgen-melody",
                     "Default"
                 ],
                 [
                     "90s rock song with electric guitar and heavy drums",
                     None,
-                    "facebook/musicgen-stereo-medium",
+                    "facebook/musicgen-medium",
                     "Default"
                 ],
                 [
                     "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
                     "./assets/bach.mp3",
-                    "facebook/musicgen-stereo-melody",
+                    "facebook/musicgen-melody",
                     "Default"
                 ],
                 [
                     "lofi slow bpm electro chill with organic samples",
                     None,
-                    "facebook/musicgen-stereo-medium",
+                    "facebook/musicgen-medium",
                     "Default"
                 ],
                 [
                     "Punk rock with loud drum and power guitar",
                     None,
-                    "facebook/musicgen-stereo-medium",
+                    "facebook/musicgen-medium",
                     "MultiBand_Diffusion"
                 ],
             ],
@@ -373,153 +370,4 @@ def ui_full(launch_kwargs):
             for crashes, snares etc.
             2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). Should improve the audio quality,
             at an extra computational cost. When this is selected, we provide both the GAN based decoded
-            audio, and the one obtained with MBD.
-
-            See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
-            for more details.
-            """
-        )
-
-    interface.queue().launch(**launch_kwargs)
-
-
-def ui_batched(launch_kwargs):
-    with gr.Blocks() as demo:
-        gr.Markdown(
-            """
-            # MusicGen
-
-            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md),
-            a simple and controllable model for music generation
-            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
-            <br/>
-            <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
-                style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-                <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
-                    src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-            for longer sequences, more control and no queue.</p>
-            """
-        )
-        with gr.Row():
-            with gr.Column():
-                with gr.Row():
-                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
-                    with gr.Column():
-                        radio = gr.Radio(["file", "mic"], value="file",
-                                         label="Condition on a melody (optional) File or Mic")
-                        melody = gr.Audio(sources="upload", type="numpy", label="File",
-                                          interactive=True, elem_id="melody-input")
-                with gr.Row():
-                    submit = gr.Button("Generate")
-            with gr.Column():
-                output = gr.Video(label="Generated Music")
-                audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
-        submit.click(predict_batched, inputs=[text, melody],
-                     outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
-        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
-        gr.Examples(
-            fn=predict_batched,
-            examples=[
-                [
-                    "An 80s driving pop song with heavy drums and synth pads in the background",
-                    "./assets/bach.mp3",
-                ],
-                [
-                    "A cheerful country song with acoustic guitars",
-                    "./assets/bolero_ravel.mp3",
-                ],
-                [
-                    "90s rock song with electric guitar and heavy drums",
-                    None,
-                ],
-                [
-                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
-                    "./assets/bach.mp3",
-                ],
-                [
-                    "lofi slow bpm electro chill with organic samples",
-                    None,
-                ],
-            ],
-            inputs=[text, melody],
-            outputs=[output]
-        )
-        gr.Markdown("""
-        ### More details
-
-        The model will generate 15 seconds of audio based on the description you provided.
-        The model was trained with description from a stock music catalog, descriptions that will work best
-        should include some level of details on the instruments present, along with some intended use case
-        (e.g. adding "perfect for a commercial" can somehow help).
-
-        You can optionally provide a reference audio from which a broad melody will be extracted.
-        The model will then try to follow both the description and melody provided.
-        For best results, the melody should be 30 seconds long (I know, the samples we provide are not...)
-
-        You can access more control (longer generation, more models etc.) by clicking
-        the <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
-            style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
-                src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-        (you will then need a paid GPU from HuggingFace).
-        If you have a GPU, you can run the gradio demo locally (click the link to our repo below for more info).
-        Finally, you can get a GPU for free from Google
-        and run the demo in [a Google Colab.](https://ai.honu.io/red/musicgen-colab).
-
-        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
-        for more details. All samples are generated with the `stereo-melody` model.
-        """)
-
-    demo.queue(max_size=8 * 4).launch(**launch_kwargs)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--listen',
-        type=str,
-        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
-        help='IP to listen on for connections to Gradio',
-    )
-    parser.add_argument(
-        '--username', type=str, default='', help='Username for authentication'
-    )
-    parser.add_argument(
-        '--password', type=str, default='', help='Password for authentication'
-    )
-    parser.add_argument(
-        '--server_port',
-        type=int,
-        default=0,
-        help='Port to run the server listener on',
-    )
-    parser.add_argument(
-        '--inbrowser', action='store_true', help='Open in browser'
-    )
-    parser.add_argument(
-        '--share', action='store_true', help='Share the gradio UI'
-    )
-
-    args = parser.parse_args()
-
-    launch_kwargs = {}
-    launch_kwargs['server_name'] = args.listen
-
-    if args.username and args.password:
-        launch_kwargs['auth'] = (args.username, args.password)
-    if args.server_port:
-        launch_kwargs['server_port'] = args.server_port
-    if args.inbrowser:
-        launch_kwargs['inbrowser'] = args.inbrowser
-    if args.share:
-        launch_kwargs['share'] = args.share
-
-    logging.basicConfig(level=logging.INFO, stream=sys.stderr)
-
-    # Show the interface
-    if IS_BATCHED:
-        global USE_DIFFUSION
-        USE_DIFFUSION = False
-        ui_batched(launch_kwargs)
-    else:
-        ui_full(launch_kwargs)
+        audio, and the one obtained with MBD.
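
The commit replaces the old in-process MODEL/MBD globals with a persistent worker: predict_full enqueues a task tuple on Predictor.task_queue and blocks in Predictor.get_result until model_worker sends back a (wav, diffusion_wav) pair or re-raises an exception. Below is a minimal sketch of driving that queue protocol directly, outside Gradio. It is illustrative only: the module name musicgen_app is an assumption, and it presumes two things the diff as committed does not provide — the `from tempfile import NamedTemporaryFile` import that the diff removes while predict_full still uses it, and the file's tail (the closing of ui_full's Markdown string and the `__main__` launcher), which the diff deletes without replacement.

    # Hypothetical driver for the new Predictor queue protocol (not part of the commit).
    import multiprocessing as mp

    from musicgen_app import get_predictor  # assumed module name

    if __name__ == "__main__":
        mp.set_start_method("spawn", force=True)  # CUDA state must not be forked
        predictor = get_predictor("facebook/musicgen-small")
        try:
            # predict() only enqueues the task and returns its id; generation
            # happens in the worker. gen_params flow into set_generation_params().
            task_id = predictor.predict(
                text="lofi slow bpm electro chill with organic samples",
                melody=None,  # or (sample_rate, numpy_array) as gr.Audio provides
                duration=10,
                use_diffusion=False,
                top_k=250, top_p=0.0, temperature=1.0, cfg_coef=3.0,
            )
            # get_result() blocks on the result queue and re-raises worker exceptions.
            wav, diffusion_wav = predictor.get_result(task_id)
            print("generated:", tuple(wav.shape), "mbd used:", diffusion_wav is not None)
        finally:
            predictor.shutdown()  # None sentinel ends the worker loop, then join()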
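One loose end: nothing in the new UI path ever calls Predictor.shutdown(), so the spawned worker outlives the Gradio server unless its process tree is killed. A hypothetical cleanup hook (also not in the commit) that could sit at module scope next to get_predictor:

    import atexit

    def _shutdown_predictor():
        # Join the worker via the None sentinel if one was ever started.
        global _predictor
        if _predictor is not None:
            _predictor.shutdown()
            _predictor = None

    atexit.register(_shutdown_predictor)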