Update demos/musicgen_app.py

demos/musicgen_app.py (+13 −47)
@@ -8,11 +8,9 @@ import sys
 import time
 import typing as tp
 from tempfile import NamedTemporaryFile, gettempdir
-
 from einops import rearrange
 import torch
 import gradio as gr
-
 from audiocraft.data.audio_utils import convert_audio
 from audiocraft.data.audio import audio_write
 from audiocraft.models.encodec import InterleaveStereoCompressionModel
@@ -20,17 +18,26 @@ from audiocraft.models import MusicGen, MultiBandDiffusion
 import multiprocessing as mp
 import warnings
 
-
+os.putenv("HF_HUB_ENABLE_HF_TRANSFER","1")
+os.environ["SAFETENSORS_FAST_GPU"] = "1"
+
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+torch.backends.cudnn.allow_tf32 = False
+torch.backends.cudnn.deterministic = False
+torch.backends.cudnn.benchmark = False
+# torch.backends.cuda.preferred_blas_library="cublas"
+# torch.backends.cuda.preferred_linalg_library="cusolver"
+torch.set_float32_matmul_precision("highest")
 
 class FileCleaner:
     def __init__(self, file_lifetime: float = 3600):
         self.file_lifetime = file_lifetime
         self.files = []
-
     def add(self, path: tp.Union[str, Path]):
         self._cleanup()
         self.files.append((time.time(), Path(path)))
-
     def _cleanup(self):
         now = time.time()
         for time_added, path in list(self.files):
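Reviewer note on the block added above: it pins CUDA matmuls to full IEEE fp32 by turning off TF32 and the reduced-precision bf16/fp16 accumulation paths, trading throughput on Ampere-and-newer GPUs for bit-stable output. Also, `os.putenv` writes to the process environment without updating `os.environ`, so Python code in this process that later reads `os.environ["HF_HUB_ENABLE_HF_TRANSFER"]` will not see the value; the `os.environ[...] = "1"` form used on the next line is the safer one. A minimal sketch of what the TF32 switch changes (illustrative only, CUDA required):

import torch

# Compare one fp32 matmul with TF32 allowed vs. disallowed.  The delta is
# nonzero only on TF32-capable GPUs (Ampere or newer).
if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    torch.backends.cuda.matmul.allow_tf32 = True
    fast = a @ b      # may be routed through TF32 tensor cores

    torch.backends.cuda.matmul.allow_tf32 = False
    exact = a @ b     # full fp32, the mode this commit pins

    print((fast - exact).abs().max())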
@@ -40,9 +47,9 @@ class FileCleaner:
             self.files.pop(0)
         else:
             break
+
 file_cleaner = FileCleaner()
 
-
 def convert_wav_to_mp4(wav_path, output_path=None):
     """Converts a WAV file to a waveform MP4 video using ffmpeg."""
     if output_path is None:
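`FileCleaner`'s cleanup is lazy: nothing runs on a timer, and expired files are only removed when the next `add()` call invokes `_cleanup()` (the branch elided above presumably unlinks the expired path before `self.files.pop(0)`, as in the upstream audiocraft demo). A short usage sketch, assuming the class above is in scope and using a one-second lifetime purely for illustration:

import time
from tempfile import NamedTemporaryFile

cleaner = FileCleaner(file_lifetime=1)
tmp = NamedTemporaryFile(suffix=".wav", delete=False)
cleaner.add(tmp.name)           # registered, not yet expired

time.sleep(1.5)
cleaner.add("next_output.wav")  # this add() triggers _cleanup(); tmp is purged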
@@ -62,19 +69,14 @@ def convert_wav_to_mp4(wav_path, output_path=None):
             "-preset", "fast", # Important, don't do veryslow.
             str(output_path),
         ]
-
         process = sp.run(command, capture_output=True, text=True, check=True)
         return str(output_path)
-
     except sp.CalledProcessError as e:
         print(f"Error in ffmpeg conversion: {e}")
         print(f"ffmpeg stdout: {e.stdout}")
         print(f"ffmpeg stderr: {e.stderr}")
         raise # Re-raise the exception to be caught by Gradio
 
-
-# --- Worker Process ---
-
 def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
     """
     Persistent worker process (used when NOT running as a daemon).
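Only the tail of the ffmpeg command survives the hunk boundary above. For orientation, a hypothetical reconstruction of a showwaves-style invocation (the actual filter chain and sizes in this file may differ):

import subprocess as sp

# Sketch: render the waveform as video frames and mux the wav in as audio.
command = [
    "ffmpeg", "-y",
    "-i", "input.wav",
    "-filter_complex", "showwaves=s=1280x720:mode=line",
    "-c:v", "libx264",
    "-preset", "fast",   # per the in-file comment: don't use veryslow
    "output.mp4",
]
sp.run(command, capture_output=True, text=True, check=True)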
@@ -83,14 +85,11 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
         model = MusicGen.get_pretrained(model_name, device=device)
         mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)
-
         while True:
             task = task_queue.get()
             if task is None:
                 break
-
             task_id, text, melody, duration, use_diffusion, gen_params = task
-
             try:
                 model.set_generation_params(duration=duration, **gen_params)
                 target_sr = model.sample_rate
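The worker's contract is purely positional: each task is a `(task_id, text, melody, duration, use_diffusion, gen_params)` tuple, and `None` is the shutdown sentinel. A sketch of the producing side (the `gen_params` keys here are assumptions, mirroring the usual `set_generation_params` arguments):

import multiprocessing as mp

task_queue = mp.Queue()
result_queue = mp.Queue()

# One generation request, then a clean shutdown.
task_queue.put((1, "upbeat 80s synthwave", None, 10.0, False,
                {"top_k": 250, "top_p": 0.0, "temperature": 1.0, "cfg_coef": 3.0}))
task_queue.put(None)  # the worker breaks out of its while-loop on None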
@@ -103,7 +102,6 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
                     melody_tensor = melody_tensor.unsqueeze(0)
                     melody_tensor = melody_tensor[..., :int(sr * duration)]
                     processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
-
                 if processed_melody is not None:
                     output, tokens = model.generate_with_chroma(
                         descriptions=[text],
@@ -114,9 +112,7 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
                     )
                 else:
                     output, tokens = model.generate([text], progress=True, return_tokens=True)
-
                 output = output.detach().cpu()
-
                 if use_diffusion:
                     if isinstance(model.compression_model, InterleaveStereoCompressionModel):
                         left, right = model.compression_model.get_left_right_codes(tokens)
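For stereo checkpoints the compression model interleaves the two channels' code streams, so the diffusion decoder cannot consume `tokens` directly. The lines elided after this hunk presumably follow the upstream audiocraft demo: decode both mono streams as one batch, then fold the batch axis back into channels. A hedged sketch of that pattern, reusing the worker's `model`, `tokens`, and `mbd`:

import torch
from einops import rearrange

left, right = model.compression_model.get_left_right_codes(tokens)
codes = torch.cat([left, right], dim=0)              # (2*B, K, T) mono codes
wav = mbd.tokens_to_wav(codes)                       # decoded as mono audio
wav = rearrange(wav, '(s b) c t -> b (s c) t', s=2)  # back to stereo (B, 2, T')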
@@ -129,16 +125,11 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
                     result_queue.put((task_id, (output, outputs_diffusion)))
                 else:
                     result_queue.put((task_id, (output, None)))
-
             except Exception as e:
                 result_queue.put((task_id, e))
-
     except Exception as e:
         result_queue.put((-1, e))
 
-
-# --- Predictor Class (Modified for conditional process creation) ---
-
 class Predictor:
     def __init__(self, model_name: str):
         self.model_name = model_name
@@ -190,7 +181,6 @@ class Predictor:
                     melody_tensor = melody_tensor.unsqueeze(0)
                     melody_tensor = melody_tensor[..., :int(sr * duration)]
                     processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
-
                 if processed_melody is not None:
                     output, tokens = self.model.generate_with_chroma(
                         descriptions=[text],
@@ -201,9 +191,7 @@ class Predictor:
                     )
                 else:
                     output, tokens = self.model.generate([text], progress=True, return_tokens=True)
-
                 output = output.detach().cpu()
-
                 if use_diffusion:
                     if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
                         left, right = self.model.compression_model.get_left_right_codes(tokens)
@@ -216,11 +204,8 @@ class Predictor:
                     return task_id, (output, outputs_diffusion) #Return the task id.
                 else:
                     return task_id, (output, None)
-
-
             except Exception as e:
                 return task_id, e
-
         else:
             # Use the multiprocessing queue (multi-process mode)
             self.current_task_id += 1
@@ -239,7 +224,6 @@ class Predictor:
                 result_task_id, result = self.result_queue.get()
                 if result_task_id == task_id:
                     break # Found the correct result
-
             if isinstance(result, Exception):
                 raise result
             return result
@@ -250,14 +234,12 @@ class Predictor:
             self.task_queue.put(None)
             self.process.join()
 
-
 _default_model_name = "facebook/musicgen-melody"
 
 @spaces.GPU(duration=90) # Use the decorator for Spaces
 def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
     # Initialize Predictor *INSIDE* the function
     predictor = Predictor(model)
-
     task_id, (wav, diffusion_wav) = predictor.predict( # Unpack directly!
         text=text,
         melody=melody,
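`@spaces.GPU` is the Hugging Face ZeroGPU decorator: a CUDA device is only guaranteed to be attached while the wrapped function runs, for at most `duration` seconds, which is why the `Predictor` is built inside `predict_full` instead of at import time. A minimal sketch of the pattern (the function body is illustrative):

import spaces   # the `spaces` package available on Hugging Face Spaces
import torch

@spaces.GPU(duration=90)   # GPU attached only for the duration of the call
def generate(prompt: str) -> torch.Tensor:
    # CUDA-touching work belongs here, not at module import.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.zeros(1, device=device)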
@@ -268,11 +250,9 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
         temperature=temperature,
         cfg_coef=cfg_coef,
     )
-
     # Save and return audio files
     wav_paths = []
     video_paths = []
-
     # Save standard output
     with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
         audio_write(
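The `audio_write(` call is cut off by the hunk boundary. In the upstream audiocraft demo this call loudness-normalizes and writes to the already-open temp file's name; the elided arguments are assumed to be of roughly this shape (sample rate and loudness settings are assumptions, not shown in the diff):

# Assumed shape of the elided call, after the upstream audiocraft demo:
audio_write(
    file.name, wav[0], 32000,
    strategy="loudness", loudness_headroom_db=16,
    loudness_compressor=True, add_suffix=False,
)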
@@ -285,8 +265,6 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
         video_paths.append(video_path)
         file_cleaner.add(file.name)
         file_cleaner.add(video_path)
-
-
     # Save MBD output if used
     if diffusion_wav is not None:
         with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
@@ -300,31 +278,24 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
             video_paths.append(video_path)
             file_cleaner.add(file.name)
             file_cleaner.add(video_path)
-
     # Shutdown predictor to prevent hanging processes!
-
     if not predictor.is_daemon: # Important!
         predictor.shutdown()
-
     if use_mbd:
         return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
     return video_paths[0], wav_paths[0], None, None
 
-
 def toggle_audio_src(choice):
     if choice == "mic":
         return gr.update(sources="microphone", value=None, label="Microphone")
     else:
         return gr.update(sources="upload", value=None, label="File")
 
-
 def toggle_diffusion(choice):
     if choice == "MultiBand_Diffusion":
         return [gr.update(visible=True)] * 2
     else:
         return [gr.update(visible=False)] * 2
-# --- Gradio UI ---
-
 
 def ui_full(launch_kwargs):
     with gr.Blocks() as interface:
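Both toggles return `gr.update(...)` rather than new components, which is how a Gradio event handler mutates properties of an existing component in place; `toggle_diffusion` returns a pair because it drives the two MBD output slots. A sketch of the wiring, with illustrative component names (the real ones live in the `ui_full` body, mostly elided here):

import gradio as gr

with gr.Blocks() as demo:
    source = gr.Radio(["file", "mic"], value="file", label="Melody source")
    melody = gr.Audio(sources=["upload"], type="filepath", label="File")
    # Swap the audio widget between upload and microphone in place.
    source.change(toggle_audio_src, source, [melody], queue=False)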
@@ -475,7 +446,6 @@ def ui_full(launch_kwargs):
 
     interface.queue().launch(**launch_kwargs)
 
-# --- Main Entry Point ---
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -502,12 +472,9 @@ if __name__ == '__main__':
     parser.add_argument(
         '--share', action='store_true', help='Share the gradio UI'
     )
-
     args = parser.parse_args()
-
     launch_kwargs = {}
     launch_kwargs['server_name'] = args.listen
-
     if args.username and args.password:
         launch_kwargs['auth'] = (args.username, args.password)
     if args.server_port:
@@ -516,7 +483,6 @@ if __name__ == '__main__':
     launch_kwargs['inbrowser'] = args.inbrowser
     if args.share:
         launch_kwargs['share'] = args.share
-
     logging.basicConfig(level=logging.INFO, stream=sys.stderr)
     # Added predictor shutdown
     try:
|