ghostai1 committed on
Commit
d121304
·
verified ·
1 Parent(s): 006933d

Create app.py

Browse files

API SDK from FB medium music added

Files changed (1) hide show
  1. app.py +428 -0
app.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import psutil
5
+ import time
6
+ import sys
7
+ import numpy as np
8
+ import gc
9
+ import gradio as gr
10
+ from pydub import AudioSegment
11
+ from audiocraft.models import MusicGen
12
+ from torch.cuda.amp import autocast
13
+ import warnings
14
+
15
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Limit allocator block splitting to reduce CUDA memory fragmentation.
# setdefault() keeps any value the operator already exported instead of
# silently overwriting it.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

# Check critical dependencies (advisory only: a mismatch prints a warning and
# execution continues).
if np.__version__ != "1.23.5":
    print(f"WARNING: NumPy version {np.__version__} is being used. Tested with numpy==1.23.5.")
if not torch.__version__.startswith(("2.1.0", "2.3.1")):
    print(f"WARNING: PyTorch version {torch.__version__} may not be compatible. Expected torch==2.1.0 or 2.3.1.")

# 1) DEVICE SETUP
# The script refuses to run without CUDA; fatal messages go to stderr so they
# are not lost when stdout is redirected.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
    print("ERROR: CUDA is required for GPU rendering. CPU rendering is disabled.", file=sys.stderr)
    sys.exit(1)
print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

# 2) LOAD MUSICGEN INTO VRAM
local_model_path = "./models/musicgen-medium"
# Validate the path before the try-block so the try only wraps the calls that
# can actually raise during model loading.
if not os.path.exists(local_model_path):
    print(f"ERROR: Local model path {local_model_path} does not exist.", file=sys.stderr)
    print("Please download the MusicGen medium model weights and place them in the correct directory.", file=sys.stderr)
    sys.exit(1)

try:
    print("Loading MusicGen medium model into VRAM...")
    musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
    musicgen_model.set_generation_params(
        duration=15,  # Default chunk duration
        two_step_cfg=False,  # Disable two-step CFG for stability
    )
except Exception as e:
    # Top-level boundary: any load failure is fatal for this script.
    print(f"ERROR: Failed to load MusicGen model: {e}", file=sys.stderr)
    print("Ensure model weights are correctly placed and dependencies are installed.", file=sys.stderr)
    sys.exit(1)
51
+
52
+ # 3) RESOURCE MONITORING FUNCTION
53
def print_resource_usage(stage: str):
    """Print a memory snapshot labelled with ``stage``.

    Reports the CUDA memory currently allocated and reserved by PyTorch (in
    GiB, two decimals) and the percentage returned by
    ``psutil.virtual_memory()``.
    """
    bytes_per_gib = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gib
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gib
    host_percent = psutil.virtual_memory().percent

    report_lines = (
        f"--- {stage} ---",
        f"GPU Memory Allocated: {allocated_gb:.2f} GB",
        f"GPU Memory Reserved: {reserved_gb:.2f} GB",
        f"CPU Memory Used: {host_percent}%",
        "---------------",
    )
    for line in report_lines:
        print(line)
59
+
60
+ # 4) GENRE PROMPT FUNCTIONS
61
# Single source of truth for the genre presets. Each ``set_*_prompt`` function
# below returns one entry, so the wording lives in exactly one place.
GENRE_PROMPTS = {
    "classic_rock": "Classic rock with bluesy electric guitars, steady drums, groovy bass, Hammond organ fills, and a Led Zeppelin-inspired raw energy.",
    "alternative_rock": "Alternative rock with distorted guitar riffs, punchy drums, melodic basslines, atmospheric synths, and a Nirvana-inspired grunge vibe.",
    "detroit_techno": "Detroit techno with deep pulsing synths, driving basslines, crisp hi-hats, and a rhythmic groove inspired by Juan Atkins.",
    "deep_house": "Deep house with warm analog synth chords, soulful vocal chops, deep basslines, and a laid-back groove inspired by Larry Heard.",
    "smooth_jazz": "Smooth jazz with warm saxophone leads, expressive Rhodes piano chords, soft bossa nova drums, and a George Benson-inspired feel.",
    "bebop_jazz": "Bebop jazz with fast-paced saxophone solos, intricate piano runs, walking basslines, and a Charlie Parker-inspired style.",
    "baroque_classical": "Baroque classical with harpsichord, delicate violin, cello, and a Vivaldi-inspired melodic structure.",
    "romantic_classical": "Romantic classical with lush strings, expressive piano, dramatic brass, and a Chopin-inspired melodic flow.",
    "boom_bap_hiphop": "Boom bap hip-hop with gritty sampled drums, deep basslines, jazzy piano loops, and a J Dilla-inspired groove.",
    "trap_hiphop": "Trap hip-hop with hard-hitting 808 bass, snappy snares, rapid hi-hats, and eerie synth melodies.",
    "pop_rock": "Pop rock with catchy electric guitar riffs, uplifting synths, steady drums, and a Coldplay-inspired anthemic feel.",
    "fusion_jazz": "Fusion jazz with electric piano, funky basslines, intricate drum patterns, and a Herbie Hancock-inspired groove.",
    "edm": "EDM with high-energy synth leads, pounding basslines, four-on-the-floor kicks, and a festival-ready drop.",
    "indie_folk": "Indie folk with acoustic guitars, heartfelt vocals, gentle percussion, and a Bon Iver-inspired atmosphere.",
}


def set_classic_rock_prompt() -> str:
    """Return the Classic Rock preset prompt."""
    return GENRE_PROMPTS["classic_rock"]


def set_alternative_rock_prompt() -> str:
    """Return the Alternative Rock preset prompt."""
    return GENRE_PROMPTS["alternative_rock"]


def set_detroit_techno_prompt() -> str:
    """Return the Detroit Techno preset prompt."""
    return GENRE_PROMPTS["detroit_techno"]


def set_deep_house_prompt() -> str:
    """Return the Deep House preset prompt."""
    return GENRE_PROMPTS["deep_house"]


def set_smooth_jazz_prompt() -> str:
    """Return the Smooth Jazz preset prompt."""
    return GENRE_PROMPTS["smooth_jazz"]


def set_bebop_jazz_prompt() -> str:
    """Return the Bebop Jazz preset prompt."""
    return GENRE_PROMPTS["bebop_jazz"]


def set_baroque_classical_prompt() -> str:
    """Return the Baroque Classical preset prompt."""
    return GENRE_PROMPTS["baroque_classical"]


def set_romantic_classical_prompt() -> str:
    """Return the Romantic Classical preset prompt."""
    return GENRE_PROMPTS["romantic_classical"]


def set_boom_bap_hiphop_prompt() -> str:
    """Return the Boom Bap Hip-Hop preset prompt."""
    return GENRE_PROMPTS["boom_bap_hiphop"]


def set_trap_hiphop_prompt() -> str:
    """Return the Trap Hip-Hop preset prompt."""
    return GENRE_PROMPTS["trap_hiphop"]


def set_pop_rock_prompt() -> str:
    """Return the Pop Rock preset prompt."""
    return GENRE_PROMPTS["pop_rock"]


def set_fusion_jazz_prompt() -> str:
    """Return the Fusion Jazz preset prompt."""
    return GENRE_PROMPTS["fusion_jazz"]


def set_edm_prompt() -> str:
    """Return the EDM preset prompt."""
    return GENRE_PROMPTS["edm"]


def set_indie_folk_prompt() -> str:
    """Return the Indie Folk preset prompt."""
    return GENRE_PROMPTS["indie_folk"]
102
+
103
+ # 5) AUDIO PROCESSING FUNCTIONS
104
def apply_chorus(segment, delay_ms=20, attenuation_db=6):
    """Thicken ``segment`` by mixing in a quieter, slightly delayed copy of itself.

    Args:
        segment: Audio segment to process (must support ``- dB`` and
            ``overlay(other, position=ms)``).
        delay_ms: Offset, in milliseconds, at which the copy is overlaid.
        attenuation_db: How many dB quieter the copy is than the original.

    Returns:
        The segment with the attenuated copy overlaid at ``delay_ms``.
    """
    # The copy inherits the source's frame rate, so no resampling step is
    # needed before overlaying.
    delayed = segment - attenuation_db
    return segment.overlay(delayed, position=delay_ms)
108
+
109
def apply_eq(segment):
    """Band-limit ``segment``: low-pass at 8000 Hz, then high-pass at 80 Hz."""
    return segment.low_pass_filter(8000).high_pass_filter(80)
113
+
114
def apply_limiter(segment, max_db=-3.0):
    """Reduce the gain of ``segment`` so its loudness does not exceed ``max_db``.

    This is a whole-track gain reduction driven by ``segment.dBFS``, not a
    sample-level peak limiter.

    Args:
        segment: Audio segment exposing ``dBFS`` and supporting ``- dB``.
        max_db: Loudness ceiling in dBFS.

    Returns:
        ``segment`` unchanged when it is already at or below the ceiling
        (including a silent segment whose ``dBFS`` is ``-inf``), otherwise a
        copy attenuated by exactly the excess.
    """
    excess_db = segment.dBFS - max_db
    if excess_db > 0:
        segment = segment - excess_db
    return segment
118
+
119
def apply_final_gain(segment, target_db=-12.0):
    """Shift the gain of ``segment`` so that its ``dBFS`` equals ``target_db``.

    Args:
        segment: Audio segment exposing ``dBFS`` and supporting ``+ dB``.
        target_db: Desired loudness in dBFS.

    Returns:
        The gain-adjusted segment. A silent segment (``dBFS`` of ``-inf``) is
        returned unchanged, because the required gain would be infinite.
    """
    current_db = segment.dBFS
    if current_db == float("-inf"):
        return segment
    return segment + (target_db - current_db)
122
+
123
def apply_fade(segment, fade_in_duration=2000, fade_out_duration=2000):
    """Apply a fade-in followed by a fade-out to ``segment`` and return the result.

    Both durations are passed straight to the segment's ``fade_in`` and
    ``fade_out`` methods.
    """
    return segment.fade_in(fade_in_duration).fade_out(fade_out_duration)
127
+
128
+ # 6) GENERATION & I/O FUNCTIONS
129
def _to_stereo(audio_chunk):
    """Coerce a generated waveform tensor to shape ``(2, samples)``."""
    if audio_chunk.dim() == 1:
        audio_chunk = torch.stack([audio_chunk, audio_chunk], dim=0)
    elif audio_chunk.dim() == 2 and audio_chunk.shape[0] != 2:
        # Mono or unexpected channel count: duplicate the first channel.
        first_channel = audio_chunk[:1, :]
        audio_chunk = torch.cat([first_channel, first_channel], dim=0)
    elif audio_chunk.dim() > 2:
        # NOTE(review): flattens everything into two rows; presumably only
        # meant for a (1, 2, samples) batch -- confirm.
        audio_chunk = audio_chunk.view(2, -1)

    if audio_chunk.shape[0] != 2:
        raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
    return audio_chunk


def _write_chunk_mp3(audio_chunk, sample_rate, wav_path, mp3_path):
    """Save a waveform as 24-bit WAV, transcode it to a 320 kbps MP3, delete the WAV."""
    torchaudio.save(wav_path, audio_chunk, sample_rate, bits_per_sample=24)
    # export() hands back an open file object; close it so the handle is not leaked.
    AudioSegment.from_wav(wav_path).export(mp3_path, format="mp3", bitrate="320k").close()
    os.remove(wav_path)


def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, crossfade_duration: int, num_variations: int = 1):
    """Generate one or more instrumental tracks from a text prompt.

    The track is produced as a sequence of equally long chunks from the
    module-level ``musicgen_model``; chunks are crossfaded together, trimmed
    to the requested length, post-processed and exported as
    ``output_cleaned_variation_<n>.mp3`` in the working directory.

    Args:
        instrumental_prompt: Text description passed to the model.
        cfg_scale: Passed to the model as ``cfg_coef``.
        top_k: Top-k sampling limit.
        top_p: Top-p sampling threshold.
        temperature: Sampling temperature.
        total_duration: Requested length in seconds; clamped to 10..90.
        crossfade_duration: Crossfade between chunks, in milliseconds.
        num_variations: Number of tracks to render; variation ``v`` (0-based)
            uses seed ``42 + v``.

    Returns:
        ``(path, status)``: the path of the first variation and a success
        message, or ``(None, message)`` when the prompt is empty or
        generation fails.
    """
    import tempfile  # local import: only this function needs it

    if not instrumental_prompt or not instrumental_prompt.strip():
        return None, "⚠️ Please enter a valid instrumental prompt!"
    try:
        start_time = time.time()

        # UI sliders may deliver floats; range(), slicing and top-k need ints.
        top_k = int(top_k)
        crossfade_duration = int(crossfade_duration)
        num_variations = max(1, int(num_variations))
        total_duration = min(max(int(total_duration), 10), 90)

        # Split the track into equal chunks of roughly 15 s.
        num_chunks = max(1, total_duration // 15)
        chunk_duration = total_duration / num_chunks
        # Every join consumes `crossfade_duration` of audio, so each chunk is
        # generated that much longer; otherwise long crossfades make the
        # final track shorter than requested.
        overlap_duration = crossfade_duration / 1000.0
        if num_chunks == 1:
            # Nothing is joined; keep the pad small (it is trimmed off below).
            overlap_duration = min(1.0, overlap_duration)
        generation_duration = chunk_duration + overlap_duration

        sample_rate = musicgen_model.sample_rate
        # The sampling parameters are identical for every chunk: set them once.
        musicgen_model.set_generation_params(
            duration=generation_duration,
            use_sampling=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            cfg_coef=cfg_scale,
        )

        output_files = []
        for var in range(num_variations):
            print(f"Generating variation {var+1}/{num_variations}...")
            seed = 42 + var  # Use different seeds for variations
            torch.manual_seed(seed)
            np.random.seed(seed)

            # Intermediate chunk files live in a private temporary directory
            # that is removed even when generation fails part-way.
            with tempfile.TemporaryDirectory(prefix="musicgen_chunks_") as workdir:
                chunk_paths = []
                for i in range(num_chunks):
                    print(f"Generating chunk {i+1}/{num_chunks} for variation {var+1} on GPU (prompt: {instrumental_prompt})...")
                    print_resource_usage(f"Before Chunk {i+1} Generation (Variation {var+1})")

                    with torch.no_grad(), autocast():
                        audio_chunk = musicgen_model.generate([instrumental_prompt], progress=True)[0]

                    audio_chunk = _to_stereo(audio_chunk.cpu().to(dtype=torch.float32))

                    chunk_path = os.path.join(workdir, f"chunk_{var}_{i}.mp3")
                    _write_chunk_mp3(
                        audio_chunk,
                        sample_rate,
                        os.path.join(workdir, f"temp_chunk_{var}_{i}.wav"),
                        chunk_path,
                    )
                    chunk_paths.append(chunk_path)

                    # Release GPU memory between chunks.
                    torch.cuda.empty_cache()
                    gc.collect()
                    time.sleep(0.5)
                    print_resource_usage(f"After Chunk {i+1} Generation (Variation {var+1})")

                print(f"Combining audio chunks for variation {var+1}...")
                final_segment = AudioSegment.from_mp3(chunk_paths[0])
                for chunk_path in chunk_paths[1:]:
                    # NOTE(review): the +1 dB boost on follow-up chunks is kept
                    # from the original; its purpose is not documented.
                    next_segment = AudioSegment.from_mp3(chunk_path) + 1
                    final_segment = final_segment.append(next_segment, crossfade=crossfade_duration)

            final_segment = final_segment[:total_duration * 1000]

            print(f"Post-processing final track for variation {var+1}...")
            final_segment = apply_eq(final_segment)
            final_segment = apply_chorus(final_segment)
            final_segment = apply_limiter(final_segment, max_db=-3.0)
            # Headroom must be positive: it is the distance *below* full scale.
            # A negative value asks for a peak above full scale and clips.
            final_segment = final_segment.normalize(headroom=6.0)
            final_segment = apply_final_gain(final_segment, target_db=-12.0)

            mp3_path = f"output_cleaned_variation_{var+1}.mp3"
            final_segment.export(
                mp3_path,
                format="mp3",
                bitrate="320k",
                tags={"title": f"GhostAI Instrumental Variation {var+1}", "artist": "GhostAI"},
            ).close()
            print(f"Saved final audio to {mp3_path}")
            output_files.append(mp3_path)

        print_resource_usage("After Final Generation")
        print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")

        # Return the first variation for Gradio display; others are saved to disk
        return output_files[0], f"✅ Done! Generated {num_variations} variations."
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        return None, f"❌ Generation failed: {e}"
    finally:
        torch.cuda.empty_cache()
        gc.collect()
236
+
237
def clear_inputs():
    """Return the reset values for the prompt box and the seven sliders.

    The order is: prompt, guidance scale, top-k, top-p, temperature, total
    duration, crossfade duration, number of variations (presumably matching
    the outputs wired to the "Clear Inputs" button -- verify against the
    caller).
    """
    reset_values = {
        "instrumental_prompt": "",
        "cfg_scale": 3.0,
        "top_k": 250,
        "top_p": 0.9,
        "temperature": 1.0,
        "total_duration": 30,
        "crossfade_duration": 500,
        "num_variations": 1,
    }
    return tuple(reset_values.values())
239
+
240
+ # 7) CUSTOM CSS
241
+ css = """
242
+ body {
243
+ background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
244
+ color: #E0E0E0;
245
+ font-family: 'Orbitron', sans-serif;
246
+ }
247
+ .header-container {
248
+ text-align: center;
249
+ padding: 15px 20px;
250
+ background: rgba(0, 0, 0, 0.9);
251
+ border-bottom: 1px solid #00FF9F;
252
+ }
253
+ #ghost-logo {
254
+ font-size: 60px;
255
+ animation: glitch-ghost 1.5s infinite;
256
+ }
257
+ h1 {
258
+ color: #A100FF;
259
+ font-size: 28px;
260
+ animation: glitch-text 2s infinite;
261
+ }
262
+ .input-container, .settings-container, .output-container {
263
+ max-width: 1000px;
264
+ margin: 20px auto;
265
+ padding: 20px;
266
+ background: rgba(28, 37, 38, 0.8);
267
+ border-radius: 10px;
268
+ }
269
+ .textbox {
270
+ background: #1A1A1A;
271
+ border: 1px solid #A100FF;
272
+ color: #E0E0E0;
273
+ }
274
+ .genre-buttons {
275
+ display: flex;
276
+ justify-content: center;
277
+ gap: 15px;
278
+ }
279
+ .genre-btn, button {
280
+ background: linear-gradient(45deg, #A100FF, #00FF9F);
281
+ border: none;
282
+ color: #0A0A0A;
283
+ padding: 10px 20px;
284
+ border-radius: 5px;
285
+ }
286
+ @keyframes glitch-ghost {
287
+ 0% { transform: translate(0, 0); opacity: 1; }
288
+ 20% { transform: translate(-5px, 2px); opacity: 0.8; }
289
+ 100% { transform: translate(0, 0); opacity: 1; }
290
+ }
291
+ @keyframes glitch-text {
292
+ 0% { transform: translate(0, 0); }
293
+ 20% { transform: translate(-2px, 1px); }
294
+ 100% { transform: translate(0, 0); }
295
+ }
296
+ @font-face {
297
+ font-family: 'Orbitron';
298
+ src: url('https://fonts.gstatic.com/s/orbitron/v29/yMJRMIlzdpvBhQQL_Qq7dy0.woff2') format('woff2');
299
+ }
300
+ """
301
+
302
+ # 8) BUILD WITH BLOCKS
303
+ with gr.Blocks(css=css) as demo:
304
+ gr.Markdown("""
305
+ <div class="header-container">
306
+ <div id="ghost-logo">👻</div>
307
+ <h1>GhostAI Music Generator</h1>
308
+ <p>Summon the Sound of the Unknown</p>
309
+ </div>
310
+ """)
311
+
312
+ with gr.Column(elem_classes="input-container"):
313
+ instrumental_prompt = gr.Textbox(
314
+ label="Instrumental Prompt",
315
+ placeholder="Click a genre button or type your own prompt",
316
+ lines=4,
317
+ elem_classes="textbox"
318
+ )
319
+ with gr.Row(elem_classes="genre-buttons"):
320
+ classic_rock_btn = gr.Button("Classic Rock", elem_classes="genre-btn")
321
+ alternative_rock_btn = gr.Button("Alternative Rock", elem_classes="genre-btn")
322
+ detroit_techno_btn = gr.Button("Detroit Techno", elem_classes="genre-btn")
323
+ deep_house_btn = gr.Button("Deep House", elem_classes="genre-btn")
324
+ smooth_jazz_btn = gr.Button("Smooth Jazz", elem_classes="genre-btn")
325
+ bebop_jazz_btn = gr.Button("Bebop Jazz", elem_classes="genre-btn")
326
+ baroque_classical_btn = gr.Button("Baroque Classical", elem_classes="genre-btn")
327
+ romantic_classical_btn = gr.Button("Romantic Classical", elem_classes="genre-btn")
328
+ boom_bap_hiphop_btn = gr.Button("Boom Bap Hip-Hop", elem_classes="genre-btn")
329
+ trap_hiphop_btn = gr.Button("Trap Hip-Hop", elem_classes="genre-btn")
330
+ pop_rock_btn = gr.Button("Pop Rock", elem_classes="genre-btn")
331
+ fusion_jazz_btn = gr.Button("Fusion Jazz", elem_classes="genre-btn")
332
+ edm_btn = gr.Button("EDM", elem_classes="genre-btn")
333
+ indie_folk_btn = gr.Button("Indie Folk", elem_classes="genre-btn")
334
+
335
+ with gr.Column(elem_classes="settings-container"):
336
+ cfg_scale = gr.Slider(
337
+ label="Guidance Scale (CFG)",
338
+ minimum=1.0,
339
+ maximum=10.0,
340
+ value=3.0,
341
+ step=0.1,
342
+ info="Higher values make the instrumental more closely follow the prompt."
343
+ )
344
+ top_k = gr.Slider(
345
+ label="Top-K Sampling",
346
+ minimum=10,
347
+ maximum=500,
348
+ value=250,
349
+ step=10,
350
+ info="Limits sampling to the top k most likely tokens."
351
+ )
352
+ top_p = gr.Slider(
353
+ label="Top-P Sampling",
354
+ minimum=0.0,
355
+ maximum=1.0,
356
+ value=0.9,
357
+ step=0.05,
358
+ info="Keeps tokens with cumulative probability above p."
359
+ )
360
+ temperature = gr.Slider(
361
+ label="Temperature",
362
+ minimum=0.1,
363
+ maximum=2.0,
364
+ value=1.0,
365
+ step=0.1,
366
+ info="Controls randomness. Higher values make output more diverse."
367
+ )
368
+ total_duration = gr.Slider(
369
+ label="Total Duration (seconds)",
370
+ minimum=10,
371
+ maximum=90,
372
+ value=30,
373
+ step=1,
374
+ info="Total duration of the track (10 to 90 seconds)."
375
+ )
376
+ crossfade_duration = gr.Slider(
377
+ label="Crossfade Duration (ms)",
378
+ minimum=100,
379
+ maximum=2000,
380
+ value=500,
381
+ step=100,
382
+ info="Crossfade duration between chunks."
383
+ )
384
+ num_variations = gr.Slider(
385
+ label="Number of Variations",
386
+ minimum=1,
387
+ maximum=4,
388
+ value=1,
389
+ step=1,
390
+ info="Number of different versions to generate with varying random seeds."
391
+ )
392
+ with gr.Row(elem_classes="action-buttons"):
393
+ gen_btn = gr.Button("Generate Music")
394
+ clr_btn = gr.Button("Clear Inputs")
395
+
396
+ with gr.Column(elem_classes="output-container"):
397
+ out_audio = gr.Audio(label="Generated Stereo Instrumental Track", type="filepath")
398
+ status = gr.Textbox(label="Status", interactive=False)
399
+
400
+ classic_rock_btn.click(set_classic_rock_prompt, inputs=None, outputs=[instrumental_prompt])
401
+ alternative_rock_btn.click(set_alternative_rock_prompt, inputs=None, outputs=[instrumental_prompt])
402
+ detroit_techno_btn.click(set_detroit_techno_prompt, inputs=None, outputs=[instrumental_prompt])
403
+ deep_house_btn.click(set_deep_house_prompt, inputs=None, outputs=[instrumental_prompt])
404
+ smooth_jazz_btn.click(set_smooth_jazz_prompt, inputs=None, outputs=[instrumental_prompt])
405
+ _scale=3.0,
406
+ top_k=250,
407
+ top_p=0.9,
408
+ temperature=1.0,
409
+ total_duration=30,
410
+ crossfade_duration=500,
411
+ num_variations=1
412
+ )
413
+
414
# 9) LAUNCH, THEN TURN OFF OPENAPI/DOCS
# launch() normally blocks until the server stops, so code placed after it
# would never run while the app is serving. Start without blocking, strip the
# documentation routes, then block explicitly.
app = demo.launch(
    server_name="0.0.0.0",
    server_port=9999,
    share=False,
    inbrowser=False,
    show_error=True,
    prevent_thread_lock=True,
)

# NOTE(review): Gradio exposes the FastAPI instance as `server_app`; the
# original `_server.app` lookup is kept only as a fallback -- confirm against
# the installed Gradio version.
fastapi_app = getattr(demo, "server_app", None) or getattr(getattr(demo, "_server", None), "app", None)
if fastapi_app is None:
    print("WARNING: Could not locate the FastAPI app; OpenAPI/docs routes were not disabled.")
else:
    fastapi_app.docs_url = None
    fastapi_app.redoc_url = None
    fastapi_app.openapi_url = None
    # Clearing the attributes does not unregister routes that already exist,
    # so drop them from the router as well.
    hidden_paths = {"/docs", "/redoc", "/openapi.json"}
    fastapi_app.router.routes[:] = [
        route
        for route in fastapi_app.router.routes
        if getattr(route, "path", None) not in hidden_paths
    ]

# Keep the main thread alive while the server runs.
demo.block_thread()