ghostai1 committed on
Commit
79ac61a
·
verified ·
1 Parent(s): 5852ce4

Create barks.py

Browse files

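Pair the Bark small vocal model with MusicGen medium; Bark is kept in system RAM and moved to the GPU only while generating vocals, to reduce VRAM use.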

Files changed (1) hide show
  1. barks.py +639 -0
barks.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import psutil
5
+ import time
6
+ import sys
7
+ import numpy as np
8
+ import gc
9
+ import gradio as gr
10
+ from pydub import AudioSegment
11
+ from audiocraft.models import MusicGen
12
+ from torch.cuda.amp import autocast
13
+ import warnings
14
+ import random
15
+ from transformers import AutoProcessor, BarkModel
16
+ from accelerate import Accelerator
17
+
18
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Set PYTORCH_CUDA_ALLOC_CONF to manage memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Check critical dependencies (warn only; a mismatch does not abort startup)
if np.__version__ != "1.23.5":
    print(f"WARNING: NumPy version {np.__version__} is being used. Tested with numpy==1.23.5.")
if not torch.__version__.startswith(("2.1.0", "2.3.1")):
    print(f"WARNING: PyTorch version {torch.__version__} may not be compatible. Expected torch==2.1.0 or 2.3.1.")

# 1) DEVICE SETUP
# The script refuses to run without CUDA; there is no CPU fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
    print("ERROR: CUDA is required for GPU rendering. CPU rendering is disabled.")
    sys.exit(1)
print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

# Initialize accelerator for offloading
accelerator = Accelerator(mixed_precision="fp16")

# Pre-run memory cleanup
torch.cuda.empty_cache()
gc.collect()
torch.cuda.ipc_collect()
torch.cuda.synchronize()

# 2) LOAD MODELS
try:
    print("Loading MusicGen medium model into VRAM...")
    local_model_path = "./models/musicgen-medium"
    if not os.path.exists(local_model_path):
        print(f"ERROR: Local model path {local_model_path} does not exist.")
        print("Please download the MusicGen medium model weights and place them in the correct directory.")
        # SystemExit is not an Exception subclass, so the handler below does not intercept it.
        sys.exit(1)
    musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
    musicgen_model.set_generation_params(
        duration=10,  # Default chunk duration
        two_step_cfg=False,  # Disable two-step CFG for stability
    )
except Exception as e:
    print(f"ERROR: Failed to load MusicGen model: {e}")
    print("Ensure model weights are correctly placed and dependencies are installed.")
    sys.exit(1)

try:
    print("Loading Bark small model into system RAM...")
    bark_processor = AutoProcessor.from_pretrained("suno/bark-small")
    bark_model = BarkModel.from_pretrained("suno/bark-small")
    bark_model = bark_model.to("cpu")  # Offload to CPU initially
except Exception as e:
    print(f"ERROR: Failed to load Bark model: {e}")
    print("Ensure Bark model weights are available and dependencies are installed.")
    sys.exit(1)
73
+
74
+ # 3) RESOURCE MONITORING FUNCTION
75
def print_resource_usage(stage: str) -> None:
    """Print a snapshot of GPU and CPU memory usage under the heading ``stage``.

    Reports the bytes PyTorch has allocated and reserved on the current CUDA
    device (shown in GiB) and the system-wide RAM utilisation percentage
    reported by psutil.
    """
    gib = 1024**3
    print(f"--- {stage} ---")
    print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / gib:.2f} GB")
    print(f"GPU Memory Reserved: {torch.cuda.memory_reserved() / gib:.2f} GB")
    print(f"CPU Memory Used: {psutil.virtual_memory().percent}%")
    print("---------------")
81
+
82
+ # Check available GPU memory
83
def check_vram_availability(required_gb=4.5):  # Adjusted for MusicGen + Bark
    """Return True when at least ``required_gb`` GiB of VRAM can still be used.

    The estimate is the free memory reported by the CUDA driver for the
    current device plus the memory PyTorch's caching allocator has reserved
    but is not currently using (that memory can be reused without asking the
    driver for more). Unlike "total minus allocated", this accounts for
    memory held by other processes and by the CUDA context itself.

    Args:
        required_gb: Minimum amount of usable VRAM, in GiB.

    Returns:
        bool: True if enough VRAM is available; otherwise False, after
        printing a warning.
    """
    gib = 1024**3
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    reusable_bytes = torch.cuda.memory_reserved() - torch.cuda.memory_allocated()
    available_vram = (free_bytes + reusable_bytes) / gib
    if available_vram < required_gb:
        print(f"WARNING: Low VRAM available ({available_vram:.2f} GB). Reduce total_duration or chunk_duration.")
        return False
    return True
90
+
91
+ # 4) GENRE PROMPT FUNCTIONS
92
def _compose_prompt(
    genre,
    flavor,
    bpm,
    drum_beat,
    synthesizer,
    rhythmic_steps,
    bass_style,
    guitar_style,
    *,
    fast_rhythm,
    slow_rhythm,
    guitar_noun,
    default_guitar="",
    default_bass="",
    default_drum="",
    default_synth="",
):
    """Assemble one instrumental prompt shared by all genre builders.

    Each of ``drum_beat``, ``synthesizer``, ``rhythmic_steps``, ``bass_style``
    and ``guitar_style`` is either a style name or the literal ``"none"``.
    A ``"none"`` value falls back to the matching ``default_*`` fragment
    (which carries its own leading ``", "`` or is empty). When no rhythmic
    step is chosen, ``fast_rhythm`` is used above 120 BPM and ``slow_rhythm``
    otherwise.

    Returns:
        str: ``"Instrumental <genre><bass><guitar><drum><synth>, <flavor>,
        <rhythm> at <bpm> BPM."``
    """
    if rhythmic_steps != "none":
        # No leading space here: the template already supplies ", " before
        # the rhythm, so a leading space would render as ",  with ...".
        rhythm = f"with {rhythmic_steps}"
    else:
        rhythm = fast_rhythm if bpm > 120 else slow_rhythm
    drum = f", {drum_beat} drums" if drum_beat != "none" else default_drum
    synth = f", {synthesizer} accents" if synthesizer != "none" else default_synth
    bass = f", {bass_style}" if bass_style != "none" else default_bass
    guitar = f", {guitar_style} {guitar_noun}" if guitar_style != "none" else default_guitar
    return f"Instrumental {genre}{bass}{guitar}{drum}{synth}, {flavor}, {rhythm} at {bpm} BPM."


def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Red Hot Chili Peppers-style funk rock prompt."""
    return _compose_prompt(
        "funk rock",
        "Red Hot Chili Peppers-inspired vibe with dynamic energy and funky breakdowns",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="strong rhythmic steps",
        slow_rhythm="groovy rhythmic flow",
        guitar_noun="guitar riffs",
        default_guitar=", syncopated guitar riffs",
        default_bass=", groovy basslines",
    )


def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Nirvana-style grunge prompt."""
    return _compose_prompt(
        "grunge",
        "Nirvana-inspired angst-filled sound with quiet-loud dynamics",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="intense rhythmic steps",
        slow_rhythm="grungy rhythmic pulse",
        guitar_noun="guitar riffs",
        default_guitar=", raw distorted guitar riffs",
        default_bass=", melodic basslines",
    )


def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Pearl Jam-style grunge prompt."""
    return _compose_prompt(
        "grunge",
        "Pearl Jam-inspired emotional intensity with soaring choruses",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="soulful rhythmic steps",
        slow_rhythm="driving rhythmic flow",
        guitar_noun="guitar leads",
        default_guitar=", soulful guitar leads",
        default_bass=", deep bass",
    )


def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Soundgarden-style grunge prompt."""
    return _compose_prompt(
        "grunge",
        "Soundgarden-inspired dark, psychedelic edge",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="heavy rhythmic steps",
        slow_rhythm="sludgy rhythmic groove",
        guitar_noun="guitar riffs",
        default_guitar=", heavy sludgy guitar riffs",
    )


def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Foo Fighters-style alternative rock prompt.

    The mood word, and the guitar adjective used when no guitar style is
    chosen, are picked at random, so repeated clicks give varied prompts.
    """
    style = random.choice(["anthemic", "gritty", "melodic", "fast-paced", "driving"])
    mood = random.choice(["energetic", "introspective", "rebellious", "uplifting"])
    return _compose_prompt(
        "alternative rock",
        f"Foo Fighters-inspired {mood} vibe with powerful choruses",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="powerful rhythmic steps",
        slow_rhythm="catchy rhythmic groove",
        guitar_noun="guitar riffs",
        default_guitar=f", {style} guitar riffs",
    )


def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Smashing Pumpkins-style alternative rock prompt."""
    return _compose_prompt(
        "alternative rock",
        "Smashing Pumpkins-inspired blend of melancholy and aggression",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="dynamic rhythmic steps",
        slow_rhythm="dreamy rhythmic flow",
        guitar_noun="guitar textures",
        default_guitar=", dreamy guitar textures",
    )


def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Radiohead-style experimental rock prompt."""
    return _compose_prompt(
        "experimental rock",
        "Radiohead-inspired blend of introspective and innovative soundscapes",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="complex rhythmic steps",
        slow_rhythm="intricate rhythmic pulse",
        guitar_noun="guitar layers",
        default_guitar=", intricate guitar layers",
        default_synth=", atmospheric synths",
    )


def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Led Zeppelin-style classic rock prompt."""
    return _compose_prompt(
        "classic rock",
        "Led Zeppelin-inspired raw energy with dynamic solos",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="bluesy rhythmic steps",
        slow_rhythm="steady rhythmic groove",
        guitar_noun="electric guitars",
        default_guitar=", bluesy electric guitars",
        default_bass=", groovy bass",
    )


def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Pixies-style alternative rock prompt."""
    return _compose_prompt(
        "alternative rock",
        "Pixies-inspired quirky, energetic vibe",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="quirky rhythmic steps",
        slow_rhythm="energetic rhythmic flow",
        guitar_noun="guitar riffs",
        default_guitar=", distorted guitar riffs",
        default_bass=", melodic basslines",
    )


def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Joy Division-style post-punk prompt."""
    return _compose_prompt(
        "post-punk",
        "Joy Division-inspired moody, atmospheric sound",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="sharp rhythmic steps",
        slow_rhythm="moody rhythmic pulse",
        guitar_noun="guitars",
        default_guitar=", jangly guitars",
        default_bass=", driving basslines",
    )


def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Arctic Monkeys-style indie rock prompt."""
    return _compose_prompt(
        "indie rock",
        "Arctic Monkeys-inspired blend of catchy riffs",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="catchy rhythmic steps",
        slow_rhythm="jangly rhythmic flow",
        guitar_noun="guitars",
        default_guitar=", jangly guitars",
    )


def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Rage Against the Machine-style funk rock prompt."""
    return _compose_prompt(
        "funk rock",
        "Rage Against the Machine-inspired mix of groove and aggression",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="aggressive rhythmic steps",
        slow_rhythm="funky rhythmic groove",
        guitar_noun="guitar chords",
        default_guitar=", funky guitar chords",
        default_bass=", slap bass",
    )


def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Juan Atkins-style Detroit techno prompt (no guitar by default)."""
    return _compose_prompt(
        "Detroit techno",
        "Juan Atkins-inspired rhythmic groove",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="pulsing rhythmic steps",
        slow_rhythm="deep rhythmic groove",
        guitar_noun="guitars",
        default_bass=", driving basslines",
        default_drum=", crisp hi-hats",
        default_synth=", deep pulsing synths",
    )


def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
    """Build the Larry Heard-style deep house prompt (no guitar by default)."""
    return _compose_prompt(
        "deep house",
        "Larry Heard-inspired laid-back groove",
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        fast_rhythm="soulful rhythmic steps",
        slow_rhythm="laid-back rhythmic flow",
        guitar_noun="guitars",
        default_bass=", deep basslines",
        default_synth=", warm analog synth chords",
    )
209
+
210
+ # 5) AUDIO PROCESSING FUNCTIONS
211
def apply_eq(segment):
    """Band-limit ``segment`` and return the filtered result.

    Applies an 8000 Hz low-pass filter followed by an 80 Hz high-pass filter
    using the segment's own ``low_pass_filter`` / ``high_pass_filter`` methods.
    """
    return segment.low_pass_filter(8000).high_pass_filter(80)
215
+
216
def apply_fade(segment, fade_in_duration=1000, fade_out_duration=1000):
    """Return ``segment`` with a fade-in at the start and a fade-out at the end.

    Args:
        segment: Audio segment exposing ``fade_in`` and ``fade_out``.
        fade_in_duration: Fade-in length passed to ``fade_in`` (default 1000).
        fade_out_duration: Fade-out length passed to ``fade_out`` (default 1000).
    """
    return segment.fade_in(fade_in_duration).fade_out(fade_out_duration)
220
+
221
def generate_vocals(vocal_prompt: str, total_duration: int):
    """Generate a vocal track with Bark, trimmed or padded to ``total_duration``.

    The Bark model is moved to the accelerator device for generation and is
    always moved back to the CPU afterwards, including when generation fails.

    Args:
        vocal_prompt: Text for Bark to vocalise.
        total_duration: Target length of the returned segment, in seconds.

    Returns:
        tuple: ``(segment, status)`` where ``segment`` is a pydub
        ``AudioSegment`` on success or ``None`` on failure, and ``status`` is
        a human-readable message for the UI.
    """
    global bark_model, bark_processor
    if not vocal_prompt or not vocal_prompt.strip():
        return None, "⚠️ Please enter a valid vocal prompt!"

    import tempfile  # local import: only this function needs it

    temp_vocal_path = None
    try:
        print("Generating vocals with Bark...")
        target_ms = int(total_duration * 1000)

        # Move Bark model to GPU
        bark_model = bark_model.to(accelerator.device)

        # Process vocal prompt
        inputs = bark_processor(vocal_prompt, return_tensors="pt").to(accelerator.device)

        # Generate vocals with mixed precision
        with torch.no_grad(), autocast():
            vocal_array = bark_model.generate(**inputs, do_sample=True)

        # Autocast may yield half precision; save as float32 shaped (1, samples).
        waveform = vocal_array.detach().cpu().to(torch.float32).reshape(1, -1)
        # Bark publishes its output rate on generation_config; 24 kHz is the
        # fallback (assumed to match Bark's codec -- confirm if the model changes).
        sample_rate = getattr(bark_model.generation_config, "sample_rate", 24000)

        # A unique temp file avoids collisions between concurrent requests.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_vocal_path = tmp.name
        torchaudio.save(temp_vocal_path, waveform, sample_rate)
        vocal_segment = AudioSegment.from_wav(temp_vocal_path)

        # Trim or pad to match total_duration
        vocal_segment = vocal_segment[:target_ms]
        shortfall_ms = target_ms - len(vocal_segment)
        if shortfall_ms > 0:
            vocal_segment = vocal_segment + AudioSegment.silent(duration=shortfall_ms)

        return vocal_segment, "✅ Vocals generated successfully."
    except Exception as e:
        return None, f"❌ Vocal generation failed: {e}"
    finally:
        if temp_vocal_path and os.path.exists(temp_vocal_path):
            os.remove(temp_vocal_path)
        # Always release VRAM, even after a failure.
        bark_model = bark_model.to("cpu")
        torch.cuda.empty_cache()
258
+
259
+ # 6) GENERATION & I/O FUNCTIONS
260
def _plan_chunks(total_duration, chunk_duration, crossfade_duration):
    """Return ``(num_chunks, seconds_to_generate_per_chunk)`` for a track."""
    total_duration = int(total_duration)
    # Clamp the requested chunk size to the 5-15 s range.
    requested = min(max(int(chunk_duration), 5), 15)
    num_chunks = max(1, total_duration // requested)
    # Spread the track evenly so the chunks add up to exactly total_duration.
    per_chunk = total_duration / num_chunks
    # Each crossfade removes crossfade_duration ms from the joined length, so
    # every chunk is generated that much longer; otherwise the joined track
    # ends up shorter than total_duration.
    return num_chunks, per_chunk + crossfade_duration / 1000.0


def _as_stereo(audio_chunk):
    """Coerce a 1-D or 2-D audio tensor to shape ``(2, samples)``."""
    if audio_chunk.dim() == 1:
        return torch.stack([audio_chunk, audio_chunk], dim=0)
    if audio_chunk.dim() == 2:
        if audio_chunk.shape[0] == 2:
            return audio_chunk
        # Mono (or any other channel count): duplicate the first channel.
        first = audio_chunk[:1, :]
        return torch.cat([first, first], dim=0)
    raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")


def _tensor_to_segment(audio_chunk, sample_rate):
    """Round-trip a tensor through a unique temporary 24-bit WAV into pydub."""
    import tempfile  # local import: only this helper needs it

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name
    try:
        torchaudio.save(wav_path, audio_chunk, sample_rate, bits_per_sample=24)
        return AudioSegment.from_wav(wav_path)
    finally:
        if os.path.exists(wav_path):
            os.remove(wav_path)


def generate_music(instrumental_prompt: str, vocal_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, chunk_duration: int, crossfade_duration: int, bpm: int, drum_beat: str, synthesizer: str, rhythmic_steps: str, bass_style: str, guitar_style: str):
    """Generate a track with MusicGen, optionally mix in Bark vocals, export MP3.

    The instrumental is rendered in chunks that are crossfaded together,
    trimmed to ``total_duration``, optionally overlaid with vocals, then
    equalised, normalised, faded and written to ``output_cleaned.mp3``.

    Args:
        instrumental_prompt: Text prompt for MusicGen (required).
        vocal_prompt: Text for Bark; blank means instrumental only.
        cfg_scale, top_k, top_p, temperature: MusicGen sampling parameters.
        total_duration: Track length in seconds.
        chunk_duration: Requested chunk length in seconds (clamped to 5-15).
        crossfade_duration: Crossfade between chunks, in milliseconds.
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style:
            Accepted for interface compatibility with the UI wiring; unused
            here, since they are already baked into ``instrumental_prompt``.

    Returns:
        tuple: ``(mp3_path, status)`` on success, ``(None, status)`` on failure.
    """
    global musicgen_model
    if not instrumental_prompt or not instrumental_prompt.strip():
        return None, "⚠️ Please enter a valid instrumental prompt!"
    wants_vocals = bool(vocal_prompt and vocal_prompt.strip())

    try:
        start_time = time.time()
        # UI widgets may deliver floats; range() and slicing need ints.
        total_duration = int(total_duration)
        crossfade_duration = int(crossfade_duration)
        num_chunks, generation_duration = _plan_chunks(total_duration, chunk_duration, crossfade_duration)
        sample_rate = musicgen_model.sample_rate
        audio_segments = []

        if not check_vram_availability(required_gb=4.5):
            return None, "⚠️ Insufficient VRAM for generation. Reduce total_duration or chunk_duration."

        print("Generating instrumental audio...")
        seed = 42
        torch.manual_seed(seed)
        np.random.seed(seed)

        # The sampling parameters are the same for every chunk, so set them once.
        musicgen_model.set_generation_params(
            duration=generation_duration,
            use_sampling=True,
            top_k=int(top_k),
            top_p=top_p,
            temperature=temperature,
            cfg_coef=cfg_scale,
        )

        for i in range(num_chunks):
            print(f"Generating chunk {i+1}/{num_chunks} on GPU (prompt: {instrumental_prompt})...")
            print_resource_usage(f"Before Chunk {i+1} Generation")

            with torch.no_grad(), autocast():
                audio_chunk = musicgen_model.generate([instrumental_prompt], progress=True)[0]

            audio_chunk = _as_stereo(audio_chunk.cpu().to(dtype=torch.float32))
            audio_segments.append(_tensor_to_segment(audio_chunk, sample_rate))

            # Release per-chunk GPU memory before the next chunk.
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()
            time.sleep(0.5)
            print_resource_usage(f"After Chunk {i+1} Generation")

        print("Combining instrumental chunks...")
        final_segment = audio_segments[0]
        for next_segment in audio_segments[1:]:
            # pydub: `segment + 1` applies +1 dB of gain.
            # NOTE(review): every chunk after the first is 1 dB louder than
            # the first; kept as in the original, confirm this is intended.
            next_segment = next_segment + 1
            final_segment = final_segment.append(next_segment, crossfade=crossfade_duration)

        final_segment = final_segment[:total_duration * 1000]

        # Generate vocals if provided
        if wants_vocals:
            vocal_segment, vocal_status = generate_vocals(vocal_prompt, total_duration)
            if vocal_segment is None:
                return None, vocal_status
            print("Mixing vocals with instrumental...")
            # gain_during_overlay changes the level of the base (instrumental)
            # track while the overlay plays: this ducks it by 6 dB under the vocals.
            final_segment = final_segment.overlay(vocal_segment, gain_during_overlay=-6)

        print("Post-processing final track...")
        final_segment = apply_eq(final_segment)
        # pydub's headroom is dB *below* full scale. A negative value targets
        # a peak above full scale and clips, so use 9.0 for a -9 dBFS peak.
        final_segment = final_segment.normalize(headroom=9.0)
        final_segment = apply_fade(final_segment)

        mp3_path = "output_cleaned.mp3"
        final_segment.export(
            mp3_path,
            format="mp3",
            bitrate="128k",
            tags={"title": "GhostAI Song", "artist": "GhostAI"},
        )
        print(f"Saved final audio to {mp3_path}")

        print_resource_usage("After Final Generation")
        print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")

        if wants_vocals:
            return mp3_path, "✅ Done! Generated song with vocals."
        return mp3_path, "✅ Done! Generated instrumental audio."
    except Exception as e:
        return None, f"❌ Generation failed: {e}"
    finally:
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()
371
+
372
+ # Function to clear inputs
373
def clear_inputs():
    """Return the default value for every input widget, in UI wiring order.

    Order: instrumental prompt, vocal prompt, CFG scale, top-k, top-p,
    temperature, total duration, chunk duration, crossfade duration, BPM,
    drum beat, synthesizer, rhythmic steps, bass style, guitar style.
    """
    return "", "", 3.0, 250, 0.9, 1.0, 30, 10, 1000, 120, "none", "none", "none", "none", "none"
375
+
376
+ # 7) CUSTOM CSS
377
# Stylesheet passed to gr.Blocks: dark theme, Orbitron font, glitch animations.
css = """
body {
    background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
    color: #E0E0E0;
    font-family: 'Orbitron', sans-serif;
}
.header-container {
    text-align: center;
    padding: 10px 20px;
    background: rgba(0, 0, 0, 0.9);
    border-bottom: 1px solid #00FF9F;
}
#ghost-logo {
    font-size: 40px;
    animation: glitch-ghost 1.5s infinite;
}
h1 {
    color: #A100FF;
    font-size: 24px;
    animation: glitch-text 2s infinite;
}
p {
    color: #E0E0E0;
    font-size: 12px;
}
.input-container, .settings-container, .output-container {
    max-width: 1200px;
    margin: 20px auto;
    padding: 20px;
    background: rgba(28, 37, 38, 0.8);
    border-radius: 10px;
}
.textbox {
    background: #1A1A1A;
    border: 1px solid #A100FF;
    color: #E0E0E0;
}
.genre-buttons {
    display: flex;
    justify-content: center;
    flex-wrap: wrap;
    gap: 15px;
}
.genre-btn, button {
    background: linear-gradient(45deg, #A100FF, #00FF9F);
    border: none;
    color: #0A0A0A;
    padding: 10px 20px;
    border-radius: 5px;
}
.gradio-container {
    padding: 20px;
}
.group-container {
    margin-bottom: 20px;
    padding: 15px;
    border: 1px solid #00FF9F;
    border-radius: 8px;
}
@keyframes glitch-ghost {
    0% { transform: translate(0, 0); opacity: 1; }
    20% { transform: translate(-5px, 2px); opacity: 0.8; }
    100% { transform: translate(0, 0); opacity: 1; }
}
@keyframes glitch-text {
    0% { transform: translate(0, 0); }
    20% { transform: translate(-2px, 1px); }
    100% { transform: translate(0, 0); }
}
@font-face {
    font-family: 'Orbitron';
    src: url('https://fonts.gstatic.com/s/orbitron/v29/yMJRMIlzdpvBhQQL_Qq7dy0.woff2') format('woff2');
}
"""
451
+
452
+ # 8) BUILD WITH BLOCKS
453
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
        <div class="header-container">
            <div id="ghost-logo">👻</div>
            <h1>GhostAI Music Generator 🎹</h1>
            <p>Summon the Sound of the Unknown</p>
        </div>
    """)

    with gr.Column(elem_classes="input-container"):
        gr.Markdown("### 🎸 Prompt Settings")
        instrumental_prompt = gr.Textbox(
            label="Instrumental Prompt ✍️",
            placeholder="Click a genre button or type your own instrumental prompt",
            lines=4,
            elem_classes="textbox",
        )
        vocal_prompt = gr.Textbox(
            label="Vocal Prompt 🎤",
            placeholder="Enter song lyrics or vocal description (e.g., 'Upbeat pop, male voice, singing about freedom')",
            lines=4,
            elem_classes="textbox",
        )
        with gr.Row(elem_classes="genre-buttons"):
            rhcp_btn = gr.Button("Red Hot Chili Peppers 🌶️", elem_classes="genre-btn")
            nirvana_btn = gr.Button("Nirvana Grunge 🎸", elem_classes="genre-btn")
            pearl_jam_btn = gr.Button("Pearl Jam Grunge 🦪", elem_classes="genre-btn")
            soundgarden_btn = gr.Button("Soundgarden Grunge 🌑", elem_classes="genre-btn")
            foo_fighters_btn = gr.Button("Foo Fighters 🤘", elem_classes="genre-btn")
            smashing_pumpkins_btn = gr.Button("Smashing Pumpkins 🎃", elem_classes="genre-btn")
            radiohead_btn = gr.Button("Radiohead 🧠", elem_classes="genre-btn")
            classic_rock_btn = gr.Button("Classic Rock 🎸", elem_classes="genre-btn")
            alternative_rock_btn = gr.Button("Alternative Rock 🎵", elem_classes="genre-btn")
            post_punk_btn = gr.Button("Post-Punk 🖤", elem_classes="genre-btn")
            indie_rock_btn = gr.Button("Indie Rock 🎤", elem_classes="genre-btn")
            funk_rock_btn = gr.Button("Funk Rock 🕺", elem_classes="genre-btn")
            detroit_techno_btn = gr.Button("Detroit Techno 🎛️", elem_classes="genre-btn")
            deep_house_btn = gr.Button("Deep House 🏠", elem_classes="genre-btn")

    with gr.Column(elem_classes="settings-container"):
        gr.Markdown("### ⚙️ API Settings")
        with gr.Group(elem_classes="group-container"):
            cfg_scale = gr.Slider(
                label="CFG Scale 🎯",
                minimum=1.0,
                maximum=10.0,
                value=3.0,
                step=0.1,
                info="Controls how closely the music follows the prompt.",
            )
            top_k = gr.Slider(
                label="Top-K Sampling 🔢",
                minimum=10,
                maximum=500,
                value=250,
                step=10,
                info="Limits sampling to the top k most likely tokens.",
            )
            top_p = gr.Slider(
                label="Top-P Sampling 🎰",
                minimum=0.0,
                maximum=1.0,
                value=0.9,
                step=0.05,
                info="Keeps tokens with cumulative probability above p.",
            )
            temperature = gr.Slider(
                label="Temperature 🔥",
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                info="Controls randomness; higher values increase diversity.",
            )
            total_duration = gr.Radio(
                label="Song Length ⏳ (seconds)",
                choices=[30, 60, 90, 120],
                value=30,
                info="Select the total duration of the track.",
            )
            chunk_duration = gr.Slider(
                label="Chunk Duration ⏱️ (seconds)",
                minimum=5,
                maximum=15,
                value=10,
                step=1,
                info="Duration of each chunk to render (5 to 15 seconds).",
            )
            crossfade_duration = gr.Slider(
                label="Crossfade Duration 🎶 (ms)",
                minimum=100,
                maximum=2000,
                value=1000,
                step=100,
                info="Crossfade duration between chunks.",
            )

        gr.Markdown("### 🎵 Musical Controls")
        with gr.Group(elem_classes="group-container"):
            bpm = gr.Slider(
                label="Tempo 🎵 (BPM)",
                minimum=60,
                maximum=180,
                value=120,
                step=1,
                info="Beats per minute to set the track's tempo.",
            )
            drum_beat = gr.Dropdown(
                label="Drum Beat 🥁",
                choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
                value="none",
                info="Select a drum beat style to influence the rhythm.",
            )
            synthesizer = gr.Dropdown(
                label="Synthesizer 🎹",
                choices=["none", "analog synth", "digital pad", "arpeggiated synth"],
                value="none",
                info="Select a synthesizer style for electronic accents.",
            )
            rhythmic_steps = gr.Dropdown(
                label="Rhythmic Steps 👣",
                choices=["none", "syncopated steps", "steady steps", "complex steps"],
                value="none",
                info="Select a rhythmic step style to enhance the beat.",
            )
            bass_style = gr.Dropdown(
                label="Bass Style 🎸",
                choices=["none", "slap bass", "deep bass", "melodic bass"],
                value="none",
                info="Select a bass style to shape the low end.",
            )
            guitar_style = gr.Dropdown(
                label="Guitar Style 🎸",
                choices=["none", "distorted", "clean", "jangle"],
                value="none",
                info="Select a guitar style to define the riffs.",
            )

        with gr.Row(elem_classes="action-buttons"):
            gen_btn = gr.Button("Generate Music 🚀")
            clr_btn = gr.Button("Clear Inputs 🧹")

    with gr.Column(elem_classes="output-container"):
        gr.Markdown("### 🎧 Output")
        out_audio = gr.Audio(label="Generated Song 🎵", type="filepath")
        status = gr.Textbox(label="Status 📢", interactive=False)

    # Every genre button rebuilds the instrumental prompt from the same six
    # musical controls, so wire them all in one loop.
    genre_inputs = [bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style]
    genre_buttons = (
        (rhcp_btn, set_red_hot_chili_peppers_prompt),
        (nirvana_btn, set_nirvana_grunge_prompt),
        (pearl_jam_btn, set_pearl_jam_grunge_prompt),
        (soundgarden_btn, set_soundgarden_grunge_prompt),
        (foo_fighters_btn, set_foo_fighters_prompt),
        (smashing_pumpkins_btn, set_smashing_pumpkins_prompt),
        (radiohead_btn, set_radiohead_prompt),
        (classic_rock_btn, set_classic_rock_prompt),
        (alternative_rock_btn, set_alternative_rock_prompt),
        (post_punk_btn, set_post_punk_prompt),
        (indie_rock_btn, set_indie_rock_prompt),
        (funk_rock_btn, set_funk_rock_prompt),
        (detroit_techno_btn, set_detroit_techno_prompt),
        (deep_house_btn, set_deep_house_prompt),
    )
    for genre_button, prompt_builder in genre_buttons:
        genre_button.click(prompt_builder, inputs=genre_inputs, outputs=instrumental_prompt)

    # The same widget list feeds generation and is reset by "Clear Inputs";
    # its order must match generate_music's parameters and clear_inputs' tuple.
    generation_inputs = [
        instrumental_prompt, vocal_prompt, cfg_scale, top_k, top_p, temperature,
        total_duration, chunk_duration, crossfade_duration, *genre_inputs,
    ]
    gen_btn.click(
        generate_music,
        inputs=generation_inputs,
        outputs=[out_audio, status],
    )
    clr_btn.click(
        clear_inputs,
        inputs=None,
        outputs=generation_inputs,
    )
624
+
625
+ # 9) TURN OFF OPENAPI/DOCS
626
# launch() normally blocks until the server stops, so any code after it would
# only run at shutdown. Launch without blocking, adjust the app while it is
# live, then block explicitly.
# NOTE(review): share=True together with server_name="0.0.0.0" exposes this
# GPU service through a public tunnel and on every network interface.
app = demo.launch(
    server_name="0.0.0.0",
    server_port=9999,
    share=True,
    inbrowser=False,
    show_error=True,
    prevent_thread_lock=True,
)
try:
    # launch() returns (fastapi_app, local_url, share_url).
    fastapi_app = app[0]
    fastapi_app.docs_url = None
    fastapi_app.redoc_url = None
    fastapi_app.openapi_url = None
    # Clearing the attributes does not remove routes that were registered
    # when the app was created, so drop them from the router as well.
    hidden_paths = {"/docs", "/redoc", "/openapi.json"}
    fastapi_app.router.routes[:] = [
        route
        for route in fastapi_app.router.routes
        if getattr(route, "path", None) not in hidden_paths
    ]
except Exception as exc:
    # Best effort: the UI still works if the docs cannot be disabled.
    print(f"WARNING: Could not disable OpenAPI/docs routes: {exc}")

# Keep the main thread alive now that the server runs in the background.
demo.block_thread()