ghostai1 committed on
Commit
a50352d
·
verified ·
1 Parent(s): b2feece

Update app.py

Browse files

Quality-of-life updates

Files changed (1) hide show
  1. app.py +100 -37
app.py CHANGED
@@ -11,6 +11,7 @@ from pydub import AudioSegment
11
  from audiocraft.models import MusicGen
12
  from torch.cuda.amp import autocast
13
  import warnings
 
14
 
15
  # Suppress warnings for cleaner output
16
  warnings.filterwarnings("ignore")
@@ -71,7 +72,13 @@ def set_soundgarden_grunge_prompt():
71
  return "Grunge with heavy, sludgy guitar riffs, complex drum patterns, and a Soundgarden-inspired dark, psychedelic edge with powerful vocals."
72
 
73
  def set_foo_fighters_prompt():
74
- return "Alternative rock with punchy guitar riffs, tight drums, melodic hooks, and a Foo Fighters-inspired anthemic energy with gritty verses."
 
 
 
 
 
 
75
 
76
  def set_smashing_pumpkins_prompt():
77
  return "Alternative rock with dreamy guitar textures, heavy distortion, dynamic drums, and a Smashing Pumpkins-inspired blend of melancholy and aggression."
@@ -126,34 +133,89 @@ def apply_fade(segment, fade_in_duration=2000, fade_out_duration=2000):
126
  return segment
127
 
128
  # 6) GENERATION & I/O FUNCTIONS
129
- def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, crossfade_duration: int, num_variations: int = 1):
130
  global musicgen_model
131
  if not instrumental_prompt.strip():
132
  return None, "⚠️ Please enter a valid instrumental prompt!"
133
  try:
134
  start_time = time.time()
135
  total_duration = min(max(total_duration, 10), 90)
136
- chunk_duration = 15
137
- num_chunks = max(1, total_duration // chunk_duration)
138
- chunk_duration = total_duration / num_chunks
139
- overlap_duration = min(1.0, crossfade_duration / 1000.0)
140
- generation_duration = chunk_duration + overlap_duration
141
-
142
- output_files = []
143
  sample_rate = musicgen_model.sample_rate
 
144
 
145
  for var in range(num_variations):
146
  print(f"Generating variation {var+1}/{num_variations}...")
147
- audio_chunks = []
148
  seed = 42 + var # Use different seeds for variations
149
  torch.manual_seed(seed)
150
  np.random.seed(seed)
151
 
152
- for i in range(num_chunks):
153
- chunk_prompt = instrumental_prompt
154
- print(f"Generating chunk {i+1}/{num_chunks} for variation {var+1} on GPU (prompt: {chunk_prompt})...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  musicgen_model.set_generation_params(
156
- duration=generation_duration,
157
  use_sampling=True,
158
  top_k=top_k,
159
  top_p=top_p,
@@ -161,11 +223,11 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
161
  cfg_coef=cfg_scale
162
  )
163
 
164
- print_resource_usage(f"Before Chunk {i+1} Generation (Variation {var+1})")
165
 
166
  with torch.no_grad():
167
  with autocast():
168
- audio_chunk = musicgen_model.generate([chunk_prompt], progress=True)[0]
169
 
170
  audio_chunk = audio_chunk.cpu().to(dtype=torch.float32)
171
  if audio_chunk.dim() == 1:
@@ -181,27 +243,15 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
181
  if audio_chunk.shape[0] != 2:
182
  raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
183
 
184
- temp_wav_path = f"temp_chunk_{var}_{i}.wav"
185
- chunk_path = f"chunk_{var}_{i}.mp3"
186
  torchaudio.save(temp_wav_path, audio_chunk, sample_rate, bits_per_sample=24)
187
- segment = AudioSegment.from_wav(temp_wav_path)
188
- segment.export(chunk_path, format="mp3", bitrate="320k")
189
  os.remove(temp_wav_path)
190
- audio_chunks.append(chunk_path)
191
 
192
  torch.cuda.empty_cache()
193
  gc.collect()
194
  time.sleep(0.5)
195
- print_resource_usage(f"After Chunk {i+1} Generation (Variation {var+1})")
196
-
197
- print(f"Combining audio chunks for variation {var+1}...")
198
- final_segment = AudioSegment.from_mp3(audio_chunks[0])
199
- for i in range(1, len(audio_chunks)):
200
- next_segment = AudioSegment.from_mp3(audio_chunks[i])
201
- next_segment = next_segment + 1
202
- final_segment = final_segment.append(next_segment, crossfade=crossfade_duration)
203
-
204
- final_segment = final_segment[:total_duration * 1000]
205
 
206
  print(f"Post-processing final track for variation {var+1}...")
207
  final_segment = apply_eq(final_segment)
@@ -220,8 +270,9 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
220
  print(f"Saved final audio to {mp3_path}")
221
  output_files.append(mp3_path)
222
 
223
- for chunk_path in audio_chunks:
224
- os.remove(chunk_path)
 
225
 
226
  print_resource_usage("After Final Generation")
227
  print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")
@@ -234,8 +285,12 @@ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p
234
  torch.cuda.empty_cache()
235
  gc.collect()
236
 
 
 
 
 
237
  def clear_inputs():
238
- return "", 3.0, 250, 0.9, 1.0, 30, 500, 1
239
 
240
  # 7) CUSTOM CSS
241
  css = """
@@ -384,7 +439,7 @@ with gr.Blocks(css=css) as demo:
384
  maximum=2000,
385
  value=500,
386
  step=100,
387
- info="Crossfade duration between chunks."
388
  )
389
  num_variations = gr.Slider(
390
  label="Number of Variations",
@@ -394,6 +449,11 @@ with gr.Blocks(css=css) as demo:
394
  step=1,
395
  info="Number of different versions to generate with varying random seeds."
396
  )
 
 
 
 
 
397
  with gr.Row(elem_classes="action-buttons"):
398
  gen_btn = gr.Button("Generate Music")
399
  clr_btn = gr.Button("Clear Inputs")
@@ -402,6 +462,9 @@ with gr.Blocks(css=css) as demo:
402
  out_audio = gr.Audio(label="Generated Stereo Instrumental Track", type="filepath")
403
  status = gr.Textbox(label="Status", interactive=False)
404
 
 
 
 
405
  rhcp_btn.click(set_red_hot_chili_peppers_prompt, inputs=None, outputs=[instrumental_prompt])
406
  nirvana_btn.click(set_nirvana_grunge_prompt, inputs=None, outputs=[instrumental_prompt])
407
  pearl_jam_btn.click(set_pearl_jam_grunge_prompt, inputs=None, outputs=[instrumental_prompt])
@@ -418,13 +481,13 @@ with gr.Blocks(css=css) as demo:
418
  deep_house_btn.click(set_deep_house_prompt, inputs=None, outputs=[instrumental_prompt])
419
  gen_btn.click(
420
  generate_music,
421
- inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, crossfade_duration, num_variations],
422
  outputs=[out_audio, status]
423
  )
424
  clr_btn.click(
425
  clear_inputs,
426
  inputs=None,
427
- outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, crossfade_duration, num_variations]
428
  )
429
 
430
  # 9) TURN OFF OPENAPI/DOCS
 
11
  from audiocraft.models import MusicGen
12
  from torch.cuda.amp import autocast
13
  import warnings
14
+ import random
15
 
16
  # Suppress warnings for cleaner output
17
  warnings.filterwarnings("ignore")
 
72
  return "Grunge with heavy, sludgy guitar riffs, complex drum patterns, and a Soundgarden-inspired dark, psychedelic edge with powerful vocals."
73
 
74
  def set_foo_fighters_prompt():
75
+ styles = ["anthemic", "gritty", "melodic", "fast-paced", "driving"]
76
+ tempos = ["upbeat", "mid-tempo", "high-energy"]
77
+ moods = ["energetic", "introspective", "rebellious", "uplifting"]
78
+ style = random.choice(styles)
79
+ tempo = random.choice(tempos)
80
+ mood = random.choice(moods)
81
+ return f"Alternative rock with {style} guitar riffs, {tempo} drums, melodic hooks, and a Foo Fighters-inspired {mood} vibe with powerful choruses."
82
 
83
  def set_smashing_pumpkins_prompt():
84
  return "Alternative rock with dreamy guitar textures, heavy distortion, dynamic drums, and a Smashing Pumpkins-inspired blend of melancholy and aggression."
 
133
  return segment
134
 
135
  # 6) GENERATION & I/O FUNCTIONS
136
+ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, crossfade_duration: int, num_variations: int, use_chunks: bool):
137
  global musicgen_model
138
  if not instrumental_prompt.strip():
139
  return None, "⚠️ Please enter a valid instrumental prompt!"
140
  try:
141
  start_time = time.time()
142
  total_duration = min(max(total_duration, 10), 90)
 
 
 
 
 
 
 
143
  sample_rate = musicgen_model.sample_rate
144
+ output_files = []
145
 
146
  for var in range(num_variations):
147
  print(f"Generating variation {var+1}/{num_variations}...")
 
148
  seed = 42 + var # Use different seeds for variations
149
  torch.manual_seed(seed)
150
  np.random.seed(seed)
151
 
152
+ if use_chunks:
153
+ # Chunked generation
154
+ chunk_duration = 15
155
+ num_chunks = max(1, total_duration // chunk_duration)
156
+ chunk_duration = total_duration / num_chunks
157
+ overlap_duration = min(1.0, crossfade_duration / 1000.0)
158
+ generation_duration = chunk_duration + overlap_duration
159
+ audio_chunks = []
160
+
161
+ for i in range(num_chunks):
162
+ chunk_prompt = instrumental_prompt
163
+ print(f"Generating chunk {i+1}/{num_chunks} for variation {var+1} on GPU (prompt: {chunk_prompt})...")
164
+ musicgen_model.set_generation_params(
165
+ duration=generation_duration,
166
+ use_sampling=True,
167
+ top_k=top_k,
168
+ top_p=top_p,
169
+ temperature=temperature,
170
+ cfg_coef=cfg_scale
171
+ )
172
+
173
+ print_resource_usage(f"Before Chunk {i+1} Generation (Variation {var+1})")
174
+
175
+ with torch.no_grad():
176
+ with autocast():
177
+ audio_chunk = musicgen_model.generate([chunk_prompt], progress=True)[0]
178
+
179
+ audio_chunk = audio_chunk.cpu().to(dtype=torch.float32)
180
+ if audio_chunk.dim() == 1:
181
+ audio_chunk = torch.stack([audio_chunk, audio_chunk], dim=0)
182
+ elif audio_chunk.dim() == 2 and audio_chunk.shape[0] == 1:
183
+ audio_chunk = torch.cat([audio_chunk, audio_chunk], dim=0)
184
+ elif audio_chunk.dim() == 2 and audio_chunk.shape[0] != 2:
185
+ audio_chunk = audio_chunk[:1, :]
186
+ audio_chunk = torch.cat([audio_chunk, audio_chunk], dim=0)
187
+ elif audio_chunk.dim() > 2:
188
+ audio_chunk = audio_chunk.view(2, -1)
189
+
190
+ if audio_chunk.shape[0] != 2:
191
+ raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
192
+
193
+ temp_wav_path = f"temp_chunk_{var}_{i}.wav"
194
+ chunk_path = f"chunk_{var}_{i}.mp3"
195
+ torchaudio.save(temp_wav_path, audio_chunk, sample_rate, bits_per_sample=24)
196
+ segment = AudioSegment.from_wav(temp_wav_path)
197
+ segment.export(chunk_path, format="mp3", bitrate="320k")
198
+ os.remove(temp_wav_path)
199
+ audio_chunks.append(chunk_path)
200
+
201
+ torch.cuda.empty_cache()
202
+ gc.collect()
203
+ time.sleep(0.5)
204
+ print_resource_usage(f"After Chunk {i+1} Generation (Variation {var+1})")
205
+
206
+ print(f"Combining audio chunks for variation {var+1}...")
207
+ final_segment = AudioSegment.from_mp3(audio_chunks[0])
208
+ for i in range(1, len(audio_chunks)):
209
+ next_segment = AudioSegment.from_mp3(audio_chunks[i])
210
+ next_segment = next_segment + 1
211
+ final_segment = final_segment.append(next_segment, crossfade=crossfade_duration)
212
+
213
+ final_segment = final_segment[:total_duration * 1000]
214
+ else:
215
+ # Single-shot generation
216
+ print(f"Generating full track for variation {var+1} on GPU (prompt: {instrumental_prompt})...")
217
  musicgen_model.set_generation_params(
218
+ duration=total_duration,
219
  use_sampling=True,
220
  top_k=top_k,
221
  top_p=top_p,
 
223
  cfg_coef=cfg_scale
224
  )
225
 
226
+ print_resource_usage(f"Before Full Track Generation (Variation {var+1})")
227
 
228
  with torch.no_grad():
229
  with autocast():
230
+ audio_chunk = musicgen_model.generate([instrumental_prompt], progress=True)[0]
231
 
232
  audio_chunk = audio_chunk.cpu().to(dtype=torch.float32)
233
  if audio_chunk.dim() == 1:
 
243
  if audio_chunk.shape[0] != 2:
244
  raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
245
 
246
+ temp_wav_path = f"temp_full_{var}.wav"
 
247
  torchaudio.save(temp_wav_path, audio_chunk, sample_rate, bits_per_sample=24)
248
+ final_segment = AudioSegment.from_wav(temp_wav_path)
 
249
  os.remove(temp_wav_path)
 
250
 
251
  torch.cuda.empty_cache()
252
  gc.collect()
253
  time.sleep(0.5)
254
+ print_resource_usage(f"After Full Track Generation (Variation {var+1})")
 
 
 
 
 
 
 
 
 
255
 
256
  print(f"Post-processing final track for variation {var+1}...")
257
  final_segment = apply_eq(final_segment)
 
270
  print(f"Saved final audio to {mp3_path}")
271
  output_files.append(mp3_path)
272
 
273
+ if use_chunks:
274
+ for chunk_path in audio_chunks:
275
+ os.remove(chunk_path)
276
 
277
  print_resource_usage("After Final Generation")
278
  print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")
 
285
  torch.cuda.empty_cache()
286
  gc.collect()
287
 
288
+ # Function to toggle crossfade_duration interactivity
289
+ def toggle_crossfade_interactivity(use_chunks):
290
+ return gr.update(interactive=use_chunks)
291
+
292
  def clear_inputs():
293
+ return "", 3.0, 250, 0.9, 1.0, 30, 500, 1, True
294
 
295
  # 7) CUSTOM CSS
296
  css = """
 
439
  maximum=2000,
440
  value=500,
441
  step=100,
442
+ info="Crossfade duration between chunks (only used if chunking is enabled)."
443
  )
444
  num_variations = gr.Slider(
445
  label="Number of Variations",
 
449
  step=1,
450
  info="Number of different versions to generate with varying random seeds."
451
  )
452
+ use_chunks = gr.Checkbox(
453
+ label="Generate in Chunks",
454
+ value=True,
455
+ info="Enable to generate in 15-second chunks (safer for GPU memory). Disable for single-shot generation (higher VRAM usage)."
456
+ )
457
  with gr.Row(elem_classes="action-buttons"):
458
  gen_btn = gr.Button("Generate Music")
459
  clr_btn = gr.Button("Clear Inputs")
 
462
  out_audio = gr.Audio(label="Generated Stereo Instrumental Track", type="filepath")
463
  status = gr.Textbox(label="Status", interactive=False)
464
 
465
+ # Toggle crossfade_duration interactivity
466
+ use_chunks.change(fn=toggle_crossfade_interactivity, inputs=use_chunks, outputs=crossfade_duration)
467
+
468
  rhcp_btn.click(set_red_hot_chili_peppers_prompt, inputs=None, outputs=[instrumental_prompt])
469
  nirvana_btn.click(set_nirvana_grunge_prompt, inputs=None, outputs=[instrumental_prompt])
470
  pearl_jam_btn.click(set_pearl_jam_grunge_prompt, inputs=None, outputs=[instrumental_prompt])
 
481
  deep_house_btn.click(set_deep_house_prompt, inputs=None, outputs=[instrumental_prompt])
482
  gen_btn.click(
483
  generate_music,
484
+ inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, crossfade_duration, num_variations, use_chunks],
485
  outputs=[out_audio, status]
486
  )
487
  clr_btn.click(
488
  clear_inputs,
489
  inputs=None,
490
+ outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, crossfade_duration, num_variations, use_chunks]
491
  )
492
 
493
  # 9) TURN OFF OPENAPI/DOCS