Update demos/musicgen_app.py
Browse files- demos/musicgen_app.py +16 -14
demos/musicgen_app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
import spaces
|
|
|
2 |
import argparse
|
3 |
import logging
|
4 |
import os
|
@@ -19,6 +20,7 @@ from audiocraft.models.encodec import InterleaveStereoCompressionModel
|
|
19 |
from audiocraft.models import MusicGen, MultiBandDiffusion
|
20 |
import multiprocessing as mp
|
21 |
|
|
|
22 |
# --- Utility Functions and Classes ---
|
23 |
|
24 |
class FileCleaner: # Unchanged
|
@@ -50,7 +52,7 @@ def make_waveform(*args, **kwargs): # Unchanged
|
|
50 |
return out
|
51 |
|
52 |
# --- Worker Process ---
|
53 |
-
|
54 |
def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
|
55 |
"""
|
56 |
Persistent worker process that loads the model and handles prediction tasks.
|
@@ -115,6 +117,7 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
|
|
115 |
# --- Gradio Interface Functions ---
|
116 |
|
117 |
class Predictor:
|
|
|
118 |
def __init__(self, model_name: str):
|
119 |
self.task_queue = mp.Queue()
|
120 |
self.result_queue = mp.Queue()
|
@@ -161,22 +164,19 @@ class Predictor:
|
|
161 |
"""
|
162 |
Shuts down the worker process.
|
163 |
"""
|
164 |
-
self.
|
165 |
-
|
|
|
166 |
|
|
|
167 |
|
168 |
-
|
169 |
-
_predictor = None
|
170 |
-
|
171 |
-
def get_predictor(model_name:str = 'facebook/musicgen-melody'):
|
172 |
-
global _predictor
|
173 |
-
if _predictor is None:
|
174 |
-
_predictor = Predictor(model_name)
|
175 |
-
return _predictor
|
176 |
|
|
|
177 |
def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
|
|
|
|
|
178 |
|
179 |
-
predictor = get_predictor(model)
|
180 |
task_id = predictor.predict(
|
181 |
text=text,
|
182 |
melody=melody,
|
@@ -214,6 +214,8 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
|
|
214 |
wav_paths.append(file.name)
|
215 |
video_paths.append(make_waveform(file.name)) # Make and clean up video
|
216 |
file_cleaner.add(file.name)
|
|
|
|
|
217 |
|
218 |
if use_mbd:
|
219 |
return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
|
@@ -316,7 +318,7 @@ def ui_full(launch_kwargs):
|
|
316 |
"facebook/musicgen-melody",
|
317 |
"Default"
|
318 |
],
|
319 |
-
|
320 |
"lofi slow bpm electro chill with organic samples",
|
321 |
None,
|
322 |
"facebook/musicgen-medium",
|
|
|
1 |
+
import spaces # <--- IMPORTANT: Add this import
|
2 |
+
|
3 |
import argparse
|
4 |
import logging
|
5 |
import os
|
|
|
20 |
from audiocraft.models import MusicGen, MultiBandDiffusion
|
21 |
import multiprocessing as mp
|
22 |
|
23 |
+
|
24 |
# --- Utility Functions and Classes ---
|
25 |
|
26 |
class FileCleaner: # Unchanged
|
|
|
52 |
return out
|
53 |
|
54 |
# --- Worker Process ---
|
55 |
+
#This stays the same, since the worker is designed for this purpose
|
56 |
def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
|
57 |
"""
|
58 |
Persistent worker process that loads the model and handles prediction tasks.
|
|
|
117 |
# --- Gradio Interface Functions ---
|
118 |
|
119 |
class Predictor:
|
120 |
+
#This stays the same, this is the intended design
|
121 |
def __init__(self, model_name: str):
|
122 |
self.task_queue = mp.Queue()
|
123 |
self.result_queue = mp.Queue()
|
|
|
164 |
"""
|
165 |
Shuts down the worker process.
|
166 |
"""
|
167 |
+
if self.process.is_alive():
|
168 |
+
self.task_queue.put(None) # Send sentinel value to stop the worker
|
169 |
+
self.process.join() # Wait for the process to terminate
|
170 |
|
171 |
+
# NO GLOBAL PREDICTOR ANYMORE
|
172 |
|
173 |
+
_default_model_name = 'facebook/musicgen-melody'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
@spaces.GPU(duration=60) # <--- IMPORTANT: Add this decorator
|
176 |
def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
|
177 |
+
# Initialize Predictor *INSIDE* the function
|
178 |
+
predictor = Predictor(model)
|
179 |
|
|
|
180 |
task_id = predictor.predict(
|
181 |
text=text,
|
182 |
melody=melody,
|
|
|
214 |
wav_paths.append(file.name)
|
215 |
video_paths.append(make_waveform(file.name)) # Make and clean up video
|
216 |
file_cleaner.add(file.name)
|
217 |
+
# Shutdown predictor to prevent hanging processes!
|
218 |
+
predictor.shutdown()
|
219 |
|
220 |
if use_mbd:
|
221 |
return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
|
|
|
318 |
"facebook/musicgen-melody",
|
319 |
"Default"
|
320 |
],
|
321 |
+
[
|
322 |
"lofi slow bpm electro chill with organic samples",
|
323 |
None,
|
324 |
"facebook/musicgen-medium",
|