ford442 commited on
Commit
590092f
·
verified ·
1 Parent(s): 38c73b8

Update demos/musicgen_app.py

Browse files
Files changed (1) hide show
  1. demos/musicgen_app.py +16 -14
demos/musicgen_app.py CHANGED
@@ -1,4 +1,5 @@
1
- import spaces
 
2
  import argparse
3
  import logging
4
  import os
@@ -19,6 +20,7 @@ from audiocraft.models.encodec import InterleaveStereoCompressionModel
19
  from audiocraft.models import MusicGen, MultiBandDiffusion
20
  import multiprocessing as mp
21
 
 
22
  # --- Utility Functions and Classes ---
23
 
24
  class FileCleaner: # Unchanged
@@ -50,7 +52,7 @@ def make_waveform(*args, **kwargs): # Unchanged
50
  return out
51
 
52
  # --- Worker Process ---
53
- @spaces.GPU(required=True)
54
  def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
55
  """
56
  Persistent worker process that loads the model and handles prediction tasks.
@@ -115,6 +117,7 @@ def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
115
  # --- Gradio Interface Functions ---
116
 
117
  class Predictor:
 
118
  def __init__(self, model_name: str):
119
  self.task_queue = mp.Queue()
120
  self.result_queue = mp.Queue()
@@ -161,22 +164,19 @@ class Predictor:
161
  """
162
  Shuts down the worker process.
163
  """
164
- self.task_queue.put(None) # Send sentinel value to stop the worker
165
- self.process.join() # Wait for the process to terminate
 
166
 
 
167
 
168
- # Global predictor instance
169
- _predictor = None
170
-
171
- def get_predictor(model_name:str = 'facebook/musicgen-melody'):
172
- global _predictor
173
- if _predictor is None:
174
- _predictor = Predictor(model_name)
175
- return _predictor
176
 
 
177
  def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
 
 
178
 
179
- predictor = get_predictor(model)
180
  task_id = predictor.predict(
181
  text=text,
182
  melody=melody,
@@ -214,6 +214,8 @@ def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp,
214
  wav_paths.append(file.name)
215
  video_paths.append(make_waveform(file.name)) # Make and clean up video
216
  file_cleaner.add(file.name)
 
 
217
 
218
  if use_mbd:
219
  return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
@@ -316,7 +318,7 @@ def ui_full(launch_kwargs):
316
  "facebook/musicgen-melody",
317
  "Default"
318
  ],
319
- [
320
  "lofi slow bpm electro chill with organic samples",
321
  None,
322
  "facebook/musicgen-medium",
 
1
+ import spaces # <--- IMPORTANT: Add this import
2
+
3
  import argparse
4
  import logging
5
  import os
 
20
  from audiocraft.models import MusicGen, MultiBandDiffusion
21
  import multiprocessing as mp
22
 
23
+
24
  # --- Utility Functions and Classes ---
25
 
26
  class FileCleaner: # Unchanged
 
52
  return out
53
 
54
  # --- Worker Process ---
55
+ #This stays the same, since the worker is designed for this purpose
56
  def model_worker(model_name: str, task_queue: mp.Queue, result_queue: mp.Queue):
57
  """
58
  Persistent worker process that loads the model and handles prediction tasks.
 
117
  # --- Gradio Interface Functions ---
118
 
119
  class Predictor:
120
+ #This stays the same, this is the intended design
121
  def __init__(self, model_name: str):
122
  self.task_queue = mp.Queue()
123
  self.result_queue = mp.Queue()
 
164
  """
165
  Shuts down the worker process.
166
  """
167
+ if self.process.is_alive():
168
+ self.task_queue.put(None) # Send sentinel value to stop the worker
169
+ self.process.join() # Wait for the process to terminate
170
 
171
+ # NO GLOBAL PREDICTOR ANYMORE
172
 
173
+ _default_model_name = 'facebook/musicgen-melody'
 
 
 
 
 
 
 
174
 
175
+ @spaces.GPU(duration=60) # <--- IMPORTANT: Add this decorator
176
  def predict_full(model, model_path, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
177
+ # Initialize Predictor *INSIDE* the function
178
+ predictor = Predictor(model)
179
 
 
180
  task_id = predictor.predict(
181
  text=text,
182
  melody=melody,
 
214
  wav_paths.append(file.name)
215
  video_paths.append(make_waveform(file.name)) # Make and clean up video
216
  file_cleaner.add(file.name)
217
+ # Shutdown predictor to prevent hanging processes!
218
+ predictor.shutdown()
219
 
220
  if use_mbd:
221
  return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
 
318
  "facebook/musicgen-melody",
319
  "Default"
320
  ],
321
+ [
322
  "lofi slow bpm electro chill with organic samples",
323
  None,
324
  "facebook/musicgen-medium",