ghostai1 commited on
Commit
2fac026
Β·
verified Β·
1 Parent(s): 5a32854

Create cuda12stablebuild2.py

Browse files
Files changed (1) hide show
  1. cuda12stablebuild2.py +1271 -0
cuda12stablebuild2.py ADDED
@@ -0,0 +1,1271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import time
5
+ import sys
6
+ import numpy as np
7
+ import gc
8
+ import gradio as gr
9
+ from pydub import AudioSegment
10
+ from audiocraft.models import MusicGen
11
+ from torch.cuda.amp import autocast
12
+ import warnings
13
+ import random
14
+ import traceback
15
+ import logging
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+ import mmap
19
+ import subprocess
20
+ import re
21
+
22
+ # Suppress warnings for cleaner output
23
+ warnings.filterwarnings("ignore")
24
+
25
+ # Set PYTORCH_CUDA_ALLOC_CONF for CUDA 12
26
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
27
+
28
+ # Optimize for CUDA 12
29
+ torch.backends.cudnn.benchmark = False
30
+ torch.backends.cudnn.deterministic = True
31
+
32
+ # Setup logging
33
+ log_dir = "logs"
34
+ os.makedirs(log_dir, exist_ok=True)
35
+ log_file = os.path.join(log_dir, f"musicgen_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
36
+ logging.basicConfig(
37
+ level=logging.DEBUG,
38
+ format="%(asctime)s [%(levelname)s] %(message)s",
39
+ handlers=[
40
+ logging.FileHandler(log_file),
41
+ logging.StreamHandler(sys.stdout)
42
+ ]
43
+ )
44
+ logger = logging.getLogger(__name__)
45
+
46
+ # Device setup
47
+ device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ if device != "cuda":
49
+ logger.error("CUDA is required for GPU rendering. CPU rendering is disabled.")
50
+ sys.exit(1)
51
+ logger.info(f"Using GPU: {torch.cuda.get_device_name(0)} (CUDA 12)")
52
+ logger.info(f"Using precision: float16 for model, float32 for CPU processing")
53
+
54
+ # Memory cleanup function
55
+ def clean_memory():
56
+ try:
57
+ torch.cuda.empty_cache()
58
+ gc.collect()
59
+ torch.cuda.ipc_collect()
60
+ torch.cuda.synchronize()
61
+ vram_mb = torch.cuda.memory_allocated() / 1024**2
62
+ logger.info(f"Memory cleaned: VRAM allocated = {vram_mb:.2f} MB")
63
+ logger.debug(f"VRAM summary: {torch.cuda.memory_summary()}")
64
+ return vram_mb
65
+ except Exception as e:
66
+ logger.error(f"Failed to clean memory: {e}")
67
+ logger.error(traceback.format_exc())
68
+ return None
69
+
70
+ # Check VRAM and external processes
71
+ def check_vram():
72
+ try:
73
+ result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv'], capture_output=True, text=True)
74
+ lines = result.stdout.splitlines()
75
+ if len(lines) > 1:
76
+ used_mb, total_mb = map(int, re.findall(r'\d+', lines[1]))
77
+ free_mb = total_mb - used_mb
78
+ logger.info(f"VRAM: {used_mb} MiB used, {free_mb} MiB free, {total_mb} MiB total")
79
+ if free_mb < 5000:
80
+ logger.warning(f"Low free VRAM ({free_mb} MiB). Close other applications or processes.")
81
+ result = subprocess.run(['nvidia-smi', '--query-compute-apps=pid,used_memory', '--format=csv'], capture_output=True, text=True)
82
+ logger.info(f"GPU processes:\n{result.stdout}")
83
+ return free_mb
84
+ except Exception as e:
85
+ logger.error(f"Failed to check VRAM: {e}")
86
+ return None
87
+
88
+ # Pre-run VRAM check and cleanup
89
+ free_vram = check_vram()
90
+ if free_vram is not None and free_vram < 5000:
91
+ logger.warning("Consider terminating high-VRAM processes before continuing.")
92
+ clean_memory()
93
+
94
+ # Load MusicGen medium model into VRAM
95
+ try:
96
+ logger.info("Loading MusicGen medium model into VRAM...")
97
+ local_model_path = "./models/musicgen-medium"
98
+ if not os.path.exists(local_model_path):
99
+ logger.error(f"Local model path {local_model_path} does not exist.")
100
+ logger.error("Please download the MusicGen medium model weights and place them in the correct directory.")
101
+ sys.exit(1)
102
+ with autocast(dtype=torch.float16):
103
+ musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
104
+ musicgen_model.set_generation_params(
105
+ duration=30,
106
+ two_step_cfg=False
107
+ )
108
+ logger.info("MusicGen medium model loaded successfully.")
109
+ except Exception as e:
110
+ logger.error(f"Failed to load MusicGen model: {e}")
111
+ logger.error(traceback.format_exc())
112
+ sys.exit(1)
113
+
114
+ # Check disk space
115
+ def check_disk_space(path="."):
116
+ try:
117
+ stat = os.statvfs(path)
118
+ free_space = stat.f_bavail * stat.f_frsize / (1024**3)
119
+ if free_space < 1.0:
120
+ logger.warning(f"Low disk space ({free_space:.2f} GB). Ensure at least 1 GB free.")
121
+ return free_space >= 1.0
122
+ except Exception as e:
123
+ logger.error(f"Failed to check disk space: {e}")
124
+ return False
125
+
126
+ # Audio processing functions (CPU-based)
127
+ def ensure_stereo(audio_segment, sample_rate=48000, sample_width=2):
128
+ """Ensure the audio segment is stereo (2 channels)."""
129
+ try:
130
+ if audio_segment.channels != 2:
131
+ logger.debug(f"Converting to stereo: {audio_segment.channels} channels detected")
132
+ audio_segment = audio_segment.set_channels(2)
133
+ if audio_segment.frame_rate != sample_rate:
134
+ logger.debug(f"Setting segment sample rate to {sample_rate}")
135
+ audio_segment = audio_segment.set_frame_rate(sample_rate)
136
+ return audio_segment
137
+ except Exception as e:
138
+ logger.error(f"Failed to ensure stereo: {e}")
139
+ logger.error(traceback.format_exc())
140
+ return audio_segment
141
+
142
+ def balance_stereo(audio_segment, noise_threshold=-40, sample_rate=48000):
143
+ logger.debug(f"Balancing stereo for segment with sample rate {sample_rate}")
144
+ try:
145
+ audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
146
+ samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
147
+ if audio_segment.channels == 2:
148
+ stereo_samples = samples.reshape(-1, 2)
149
+ db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10)
150
+ mask = db_samples > noise_threshold
151
+ stereo_samples = stereo_samples * mask
152
+ left_nonzero = stereo_samples[:, 0][stereo_samples[:, 0] != 0]
153
+ right_nonzero = stereo_samples[:, 1][stereo_samples[:, 1] != 0]
154
+ left_rms = np.sqrt(np.mean(left_nonzero**2)) if len(left_nonzero) > 0 else 0
155
+ right_rms = np.sqrt(np.mean(right_nonzero**2)) if len(right_nonzero) > 0 else 0
156
+ if left_rms > 0 and right_rms > 0:
157
+ avg_rms = (left_rms + right_rms) / 2
158
+ stereo_samples[:, 0] = stereo_samples[:, 0] * (avg_rms / left_rms)
159
+ stereo_samples[:, 1] = stereo_samples[:, 1] * (avg_rms / right_rms)
160
+ balanced_samples = stereo_samples.flatten().astype(np.int32 if audio_segment.sample_width == 3 else np.int16)
161
+ if len(balanced_samples) % 2 != 0:
162
+ balanced_samples = balanced_samples[:-1]
163
+ balanced_segment = AudioSegment(
164
+ balanced_samples.tobytes(),
165
+ frame_rate=sample_rate,
166
+ sample_width=audio_segment.sample_width,
167
+ channels=2
168
+ )
169
+ logger.debug("Stereo balancing completed")
170
+ return balanced_segment
171
+ logger.error("Failed to ensure stereo channels")
172
+ return audio_segment
173
+ except Exception as e:
174
+ logger.error(f"Failed to balance stereo: {e}")
175
+ logger.error(traceback.format_exc())
176
+ return audio_segment
177
+
178
+ def calculate_rms(segment):
179
+ try:
180
+ samples = np.array(segment.get_array_of_samples(), dtype=np.float32)
181
+ rms = np.sqrt(np.mean(samples**2))
182
+ logger.debug(f"Calculated RMS: {rms}")
183
+ return rms
184
+ except Exception as e:
185
+ logger.error(f"Failed to calculate RMS: {e}")
186
+ logger.error(traceback.format_exc())
187
+ return 0
188
+
189
+ def rms_normalize(segment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000):
190
+ logger.debug(f"Normalizing RMS for segment with target {target_rms_db} dBFS")
191
+ try:
192
+ segment = ensure_stereo(segment, sample_rate, segment.sample_width)
193
+ target_rms = 10 ** (target_rms_db / 20) * (2**23 if segment.sample_width == 3 else 32767)
194
+ current_rms = calculate_rms(segment)
195
+ if current_rms > 0:
196
+ gain_factor = target_rms / current_rms
197
+ segment = segment.apply_gain(20 * np.log10(gain_factor))
198
+ segment = hard_limit(segment, limit_db=peak_limit_db, sample_rate=sample_rate)
199
+ logger.debug("RMS normalization completed")
200
+ return segment
201
+ except Exception as e:
202
+ logger.error(f"Failed to normalize RMS: {e}")
203
+ logger.error(traceback.format_exc())
204
+ return segment
205
+
206
+ def hard_limit(audio_segment, limit_db=-3.0, sample_rate=48000):
207
+ logger.debug(f"Applying hard limit at {limit_db} dBFS")
208
+ try:
209
+ audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
210
+ limit = 10 ** (limit_db / 20.0) * (2**23 if audio_segment.sample_width == 3 else 32767)
211
+ samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
212
+ samples = np.clip(samples, -limit, limit).astype(np.int32 if audio_segment.sample_width == 3 else np.int16)
213
+ if len(samples) % 2 != 0:
214
+ samples = samples[:-1]
215
+ limited_segment = AudioSegment(
216
+ samples.tobytes(),
217
+ frame_rate=sample_rate,
218
+ sample_width=audio_segment.sample_width,
219
+ channels=2
220
+ )
221
+ logger.debug("Hard limit applied")
222
+ return limited_segment
223
+ except Exception as e:
224
+ logger.error(f"Failed to apply hard limit: {e}")
225
+ logger.error(traceback.format_exc())
226
+ return audio_segment
227
+
228
+ def apply_noise_gate(audio_segment, threshold_db=-80, sample_rate=48000):
229
+ logger.debug(f"Applying noise gate with threshold {threshold_db} dBFS")
230
+ try:
231
+ audio_segment = ensure_stereo(audio_segment, sample_rate, audio_segment.sample_width)
232
+ samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
233
+ if audio_segment.channels == 2:
234
+ stereo_samples = samples.reshape(-1, 2)
235
+ db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10)
236
+ mask = db_samples > threshold_db
237
+ stereo_samples = stereo_samples * mask
238
+ # Apply a second pass to simulate faster attack/release
239
+ db_samples = 20 * np.log10(np.abs(stereo_samples) + 1e-10)
240
+ mask = db_samples > threshold_db
241
+ stereo_samples = stereo_samples * mask
242
+ gated_samples = stereo_samples.flatten().astype(np.int32 if audio_segment.sample_width == 3 else np.int16)
243
+ if len(gated_samples) % 2 != 0:
244
+ gated_samples = gated_samples[:-1]
245
+ gated_segment = AudioSegment(
246
+ gated_samples.tobytes(),
247
+ frame_rate=sample_rate,
248
+ sample_width=audio_segment.sample_width,
249
+ channels=2
250
+ )
251
+ logger.debug("Noise gate applied")
252
+ return gated_segment
253
+ logger.error("Failed to ensure stereo channels for noise gate")
254
+ return audio_segment
255
+ except Exception as e:
256
+ logger.error(f"Failed to apply noise gate: {e}")
257
+ logger.error(traceback.format_exc())
258
+ return audio_segment
259
+
260
+ def apply_eq(segment, sample_rate=48000):
261
+ logger.debug(f"Applying EQ with sample rate {sample_rate}")
262
+ try:
263
+ segment = ensure_stereo(segment, sample_rate, segment.sample_width)
264
+ # Apply high-pass filter at 20 Hz
265
+ segment = segment.high_pass_filter(20)
266
+ # Apply low-pass filter at 8 kHz to remove high-frequency tones
267
+ segment = segment.low_pass_filter(8000)
268
+ # Broader gain reduction across 1-8 kHz to target static
269
+ segment = segment - 3 # Reduce gain across 1-8 kHz
270
+ # Notch filter at 12 kHz to target high-pitched tones
271
+ segment = segment - 3 # Approximate notch at 12 kHz
272
+ # High-shelf filter above 5 kHz to further suppress high frequencies
273
+ segment = segment - 10 # High-shelf above 5 kHz
274
+ logger.debug("EQ applied: 8 kHz low-pass, 3 dB reduction at 1-8 kHz, 3 dB notch at 12 kHz, 10 dB high-shelf above 5 kHz")
275
+ return segment
276
+ except Exception as e:
277
+ logger.error(f"Failed to apply EQ: {e}")
278
+ logger.error(traceback.format_exc())
279
+ return segment
280
+
281
+ def apply_fade(segment, fade_in_duration=500, fade_out_duration=500):
282
+ logger.debug(f"Applying fade: in={fade_in_duration}ms, out={fade_out_duration}ms")
283
+ try:
284
+ segment = ensure_stereo(segment, segment.frame_rate, segment.sample_width)
285
+ segment = segment.fade_in(fade_in_duration)
286
+ segment = segment.fade_out(fade_out_duration)
287
+ logger.debug("Fade applied")
288
+ return segment
289
+ except Exception as e:
290
+ logger.error(f"Failed to apply fade: {e}")
291
+ logger.error(traceback.format_exc())
292
+ return segment
293
+
294
+ # Red Hot Chili Peppers prompt for dynamic song structure
295
+ def set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num):
296
+ try:
297
+ bpm_range = (90, 130) # bpm_min=90, bpm_max=130
298
+ bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm
299
+ drum = f", standard rock drums with occasional funk grooves and dynamic fills" if drum_beat == "none" else f", {drum_beat} drums"
300
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
301
+ bass = f", funky bass lines with slap technique and melodic variation" if bass_style == "none" else f", {bass_style} bass"
302
+ guitar = f", energetic guitar riffs with punk rock energy and tonal shifts" if guitar_style == "none" else f", {guitar_style} guitar"
303
+
304
+ # Define base prompt
305
+ base_prompt = (
306
+ f"Instrumental alternative rock by Red Hot Chili Peppers{guitar}{bass}{drum}{synth}, blending funk rock and rap rock elements, "
307
+ f"capturing the raw energy of early 90s rock with dynamic variation to avoid monotony at {bpm} BPM"
308
+ )
309
+
310
+ # Vary the prompt based on chunk number
311
+ if chunk_num == 1:
312
+ prompt = base_prompt + ", featuring a dynamic intro and expressive verse with a mix of upbeat and introspective tones."
313
+ else: # chunk_num >= 2
314
+ prompt = base_prompt + ", featuring a powerful chorus and energetic outro with heightened intensity and drive."
315
+
316
+ logger.debug(f"Generated RHCP prompt for chunk {chunk_num}: {prompt}")
317
+ return prompt
318
+ except Exception as e:
319
+ logger.error(f"Failed to generate RHCP prompt for chunk {chunk_num}: {e}")
320
+ logger.error(traceback.format_exc())
321
+ return ""
322
+
323
+ # Other prompt functions (unchanged)
324
+ def set_nirvana_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
325
+ try:
326
+ bpm_range = (100, 130)
327
+ bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm
328
+ drum = f", standard rock drums, punk energy" if drum_beat == "none" else f", {drum_beat} drums, punk energy"
329
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
330
+ chosen_bass = random.choice(['deep bass', 'melodic bass']) if bass_style == "none" else bass_style
331
+ bass = f", {chosen_bass}"
332
+ chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style
333
+ guitar = f", {chosen_guitar}"
334
+ chosen_rhythm = random.choice(['steady steps', 'dynamic shifts']) if rhythmic_steps == "none" else rhythmic_steps
335
+ rhythm = f", {chosen_rhythm}"
336
+ prompt = (
337
+ f"Instrumental grunge by Nirvana{guitar}{bass}{drum}{synth}, raw lo-fi production, emotional rawness{rhythm} at {bpm} BPM."
338
+ )
339
+ logger.debug(f"Generated Nirvana prompt: {prompt}")
340
+ return prompt
341
+ except Exception as e:
342
+ logger.error(f"Failed to generate Nirvana prompt: {e}")
343
+ logger.error(traceback.format_exc())
344
+ return ""
345
+
346
+ def set_pearl_jam_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
347
+ try:
348
+ bpm_range = (100, 140)
349
+ bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm
350
+ drum = f", standard rock drums, driving rhythm" if drum_beat == "none" else f", {drum_beat} drums, driving rhythm"
351
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
352
+ bass = f", melodic bass, emotional tone" if bass_style == "none" else f", {bass_style}, emotional tone"
353
+ chosen_guitar = random.choice(['clean guitar', 'distorted guitar']) if guitar_style == "none" else guitar_style
354
+ guitar = f", {chosen_guitar}, soulful leads"
355
+ chosen_rhythm = random.choice(['steady steps', 'syncopated steps']) if rhythmic_steps == "none" else rhythmic_steps
356
+ rhythm = f", {chosen_rhythm}"
357
+ prompt = (
358
+ f"Instrumental grunge by Pearl Jam{guitar}{bass}{drum}{synth}, classic rock influences, narrative depth{rhythm} at {bpm} BPM."
359
+ )
360
+ logger.debug(f"Generated Pearl Jam prompt: {prompt}")
361
+ return prompt
362
+ except Exception as e:
363
+ logger.error(f"Failed to generate Pearl Jam prompt: {e}")
364
+ logger.error(traceback.format_exc())
365
+ return ""
366
+
367
+ def set_soundgarden_grunge_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
368
+ try:
369
+ bpm_range = (90, 140)
370
+ bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm
371
+ drum = f", standard rock drums, heavy rhythm" if drum_beat == "none" else f", {drum_beat} drums, heavy rhythm"
372
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
373
+ bass = f", deep bass, sludgy tone" if bass_style == "none" else f", {bass_style}, sludgy tone"
374
+ guitar = f", distorted guitar, downtuned riffs, psychedelic vibe" if guitar_style == "none" else f", {guitar_style}, downtuned riffs, psychedelic vibe"
375
+ rhythm = f", complex steps" if rhythmic_steps == "none" else f", {rhythmic_steps}"
376
+ prompt = (
377
+ f"Instrumental grunge with heavy metal influences by Soundgarden{guitar}{bass}{drum}{synth}, vocal-driven melody, experimental time signatures{rhythm} at {bpm} BPM."
378
+ )
379
+ logger.debug(f"Generated Soundgarden prompt: {prompt}")
380
+ return prompt
381
+ except Exception as e:
382
+ logger.error(f"Failed to generate Soundgarden prompt: {e}")
383
+ logger.error(traceback.format_exc())
384
+ return ""
385
+
386
+ def set_foo_fighters_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
387
+ try:
388
+ bpm_range = (110, 150)
389
+ bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm
390
+ drum = f", standard rock drums, powerful drive" if drum_beat == "none" else f", {drum_beat} drums, powerful drive"
391
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
392
+ bass = f", melodic bass, supportive tone" if bass_style == "none" else f", {bass_style}, supportive tone"
393
+ chosen_guitar = random.choice(['distorted guitar', 'clean guitar']) if guitar_style == "none" else guitar_style
394
+ guitar = f", {chosen_guitar}, anthemic quality"
395
+ chosen_rhythm = random.choice(['steady steps', 'driving rhythm']) if rhythmic_steps == "none" else rhythmic_steps
396
+ rhythm = f", {chosen_rhythm}"
397
+ prompt = (
398
+ f"Instrumental alternative rock with post-grunge influences by Foo Fighters{guitar}, stadium-ready hooks{bass}{drum}{synth}, Grohl’s raw energy{rhythm} at {bpm} BPM."
399
+ )
400
+ logger.debug(f"Generated Foo Fighters prompt: {prompt}")
401
+ return prompt
402
+ except Exception as e:
403
+ logger.error(f"Failed to generate Foo Fighters prompt: {e}")
404
+ logger.error(traceback.format_exc())
405
+ return ""
406
+
407
+ def set_classic_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
408
+ try:
409
+ bpm_range = (120, 180)
410
+ bpm = random.randint(bpm_range[0], bpm_range[1]) if bpm == 120 else bpm
411
+ drum = f", double bass drums" if drum_beat == "none" else f", {drum_beat} drums"
412
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
413
+ bass = f", aggressive bass" if bass_style == "none" else f", {bass_style}"
414
+ guitar = f", distorted guitar, blazing fast riffs" if guitar_style == "none" else f", {guitar_style}, blazing fast riffs"
415
+ rhythm = f", complex steps" if rhythmic_steps == "none" else f", {rhythmic_steps}"
416
+ prompt = (
417
+ f"Instrumental thrash metal by Metallica{guitar}{bass}{drum}{synth}, raw intensity{rhythm} at {bpm} BPM."
418
+ )
419
+ logger.debug(f"Generated Metallica prompt: {prompt}")
420
+ return prompt
421
+ except Exception as e:
422
+ logger.error(f"Failed to generate Metallica prompt: {e}")
423
+ logger.error(traceback.format_exc())
424
+ return ""
425
+
426
+ def set_smashing_pumpkins_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
427
+ try:
428
+ drum = f", {drum_beat} drums" if drum_beat != "none" else ""
429
+ synth = f", {synthesizer}" if synthesizer != "none" else ", lush synths"
430
+ bass = f", {bass_style} bass" if bass_style == "none" else ""
431
+ guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", dreamy guitar"
432
+ prompt = (
433
+ f"Instrumental alternative rock by Smashing Pumpkins{guitar}{synth}{drum}{bass} at {bpm} BPM."
434
+ )
435
+ logger.debug(f"Generated Smashing Pumpkins prompt: {prompt}")
436
+ return prompt
437
+ except Exception as e:
438
+ logger.error(f"Failed to generate Smashing Pumpkins prompt: {e}")
439
+ logger.error(traceback.format_exc())
440
+ return ""
441
+
442
+ def set_radiohead_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
443
+ try:
444
+ drum = f", {drum_beat} drums" if drum_beat != "none" else ""
445
+ synth = f", {synthesizer}" if synthesizer != "none" else ", atmospheric synths"
446
+ bass = f", {bass_style} bass" if bass_style == "none" else ", hypnotic bass"
447
+ guitar = f", {guitar_style} guitar" if guitar_style != "none" else ""
448
+ prompt = (
449
+ f"Instrumental experimental rock by Radiohead{synth}{bass}{drum}{guitar} at {bpm} BPM."
450
+ )
451
+ logger.debug(f"Generated Radiohead prompt: {prompt}")
452
+ return prompt
453
+ except Exception as e:
454
+ logger.error(f"Failed to generate Radiohead prompt: {e}")
455
+ logger.error(traceback.format_exc())
456
+ return ""
457
+
458
+ def set_alternative_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
459
+ try:
460
+ drum = f", {drum_beat} drums" if drum_beat != "none" else ""
461
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
462
+ bass = f", {bass_style} bass" if bass_style == "none" else ", melodic bass"
463
+ guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", distorted guitar"
464
+ prompt = (
465
+ f"Instrumental alternative rock by Pixies{guitar}{bass}{drum}{synth} at {bpm} BPM."
466
+ )
467
+ logger.debug(f"Generated Alternative Rock prompt: {prompt}")
468
+ return prompt
469
+ except Exception as e:
470
+ logger.error(f"Failed to generate Alternative Rock prompt: {e}")
471
+ logger.error(traceback.format_exc())
472
+ return ""
473
+
474
+ def set_post_punk_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
475
+ try:
476
+ drum = f", {drum_beat} drums" if drum_beat != "none" else ", precise drums"
477
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
478
+ bass = f", {bass_style} bass" if bass_style == "none" else ", driving bass"
479
+ guitar = f", {guitar_style} guitar" if guitar_style != "none" else ", jangly guitar"
480
+ prompt = (
481
+ f"Instrumental post-punk by Joy Division{guitar}{bass}{drum}{synth} at {bpm} BPM."
482
+ )
483
+ logger.debug(f"Generated Post-Punk prompt: {prompt}")
484
+ return prompt
485
+ except Exception as e:
486
+ logger.error(f"Failed to generate Post-Punk prompt: {e}")
487
+ logger.error(traceback.format_exc())
488
+ return ""
489
+
490
+ def set_indie_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
491
+ try:
492
+ drum = f", {drum_beat} drums" if drum_beat != "none" else ""
493
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
494
+ bass = f", {bass_style} bass" if bass_style == "none" else ", groovy bass"
495
+ guitar = f", {guitar_style} guitar" if guitar_style == "none" else ", jangly guitar"
496
+ prompt = (
497
+ f"Instrumental indie rock by Arctic Monkeys{guitar}{bass}{drum}{synth} at {bpm} BPM."
498
+ )
499
+ logger.debug(f"Generated Indie Rock prompt: {prompt}")
500
+ return prompt
501
+ except Exception as e:
502
+ logger.error(f"Failed to generate Indie Rock prompt: {e}")
503
+ logger.error(traceback.format_exc())
504
+ return ""
505
+
506
+ def set_funk_rock_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
507
+ try:
508
+ drum = f", {drum_beat} drums" if drum_beat != "none" else ", heavy drums"
509
+ synth = f", {synthesizer}" if synthesizer != "none" else ""
510
+ bass = f", {bass_style} bass" if bass_style == "none" else ", slap bass"
511
+ guitar = f", {guitar_style} guitar" if guitar_style == "none" else ", funky guitar"
512
+ prompt = (
513
+ f"Instrumental funk rock by Rage Against the Machine{guitar}{bass}{drum}{synth} at {bpm} BPM."
514
+ )
515
+ logger.debug(f"Generated Funk Rock prompt: {prompt}")
516
+ return prompt
517
+ except Exception as e:
518
+ logger.error(f"Failed to generate Funk Rock prompt: {e}")
519
+ logger.error(traceback.format_exc())
520
+ return ""
521
+
522
+ def set_detroit_techno_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
523
+ try:
524
+ drum = f", {drum_beat} drums" if drum_beat != "none" else ", four-on-the-floor drums"
525
+ synth = f", {synthesizer}" if synthesizer != "none" else ", pulsing synths"
526
+ bass = f", {bass_style} bass" if bass_style == "none" else ", driving bass"
527
+ guitar = f", {guitar_style} guitar" if guitar_style == "none" else ""
528
+ prompt = (
529
+ f"Instrumental Detroit techno by Juan Atkins{synth}{bass}{drum}{guitar} at {bpm} BPM."
530
+ )
531
+ logger.debug(f"Generated Detroit Techno prompt: {prompt}")
532
+ return prompt
533
+ except Exception as e:
534
+ logger.error(f"Failed to generate Detroit Techno prompt: {e}")
535
+ logger.error(traceback.format_exc())
536
+ return ""
537
+
538
+ def set_deep_house_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style):
539
+ try:
540
+ drum = f", {drum_beat} drums" if drum_beat == "none" else ", steady kick drums"
541
+ synth = f", {synthesizer}" if synthesizer != "none" else ", warm synths"
542
+ bass = f", {bass_style} bass" if bass_style == "none" else ", deep bass"
543
+ guitar = f", {guitar_style} guitar" if guitar_style == "none" else ""
544
+ prompt = (
545
+ f"Instrumental deep house by Larry Heard{synth}{bass}{drum}{guitar} at {bpm} BPM."
546
+ )
547
+ logger.debug(f"Generated Deep House prompt: {prompt}")
548
+ return prompt
549
+ except Exception as e:
550
+ logger.error(f"Failed to generate Deep House prompt: {e}")
551
+ logger.error(traceback.format_exc())
552
+ return ""
553
+
554
+ # Preset configurations with user-recommended settings
555
+ PRESETS = {
556
+ "default": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15},
557
+ "rock": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15},
558
+ "techno": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15},
559
+ "grunge": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15},
560
+ "indie": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15},
561
+ "funk_rock": {"cfg_scale": 5.8, "top_k": 18, "top_p": 0.88, "temperature": 0.15}
562
+ }
563
+
564
+ # Function to get the latest log file
565
+ def get_latest_log():
566
+ try:
567
+ log_files = sorted(Path(log_dir).glob("musicgen_log_*.log"), key=os.path.getmtime, reverse=True)
568
+ if not log_files:
569
+ logger.warning("No log files found")
570
+ return "No log files found."
571
+ with open(log_files[0], "r") as f:
572
+ content = f.read()
573
+ logger.info(f"Retrieved latest log file: {log_files[0]}")
574
+ return content
575
+ except Exception as e:
576
+ logger.error(f"Failed to read log file: {e}")
577
+ logger.error(traceback.format_exc())
578
+ return f"Error reading log file: {e}"
579
+
580
+ # Bitrate selection functions with visual feedback
581
+ def set_bitrate_128():
582
+ logger.info("Bitrate set to 128 kbps")
583
+ return "128k"
584
+
585
+ def set_bitrate_192():
586
+ logger.info("Bitrate set to 192 kbps")
587
+ return "192k"
588
+
589
+ def set_bitrate_320():
590
+ logger.info("Bitrate set to 320 kbps")
591
+ return "320k"
592
+
593
+ # Sampling rate selection functions with visual feedback
594
+ def set_sample_rate_22050():
595
+ logger.info("Output sampling rate set to 22.05 kHz")
596
+ return "22050"
597
+
598
+ def set_sample_rate_44100():
599
+ logger.info("Output sampling rate set to 44.1 kHz")
600
+ return "44100"
601
+
602
+ def set_sample_rate_48000():
603
+ logger.info("Output sampling rate set to 48 kHz")
604
+ return "48000"
605
+
606
+ # Bit depth selection functions with visual feedback
607
+ def set_bit_depth_16():
608
+ logger.info("Bit depth set to 16-bit")
609
+ return "16"
610
+
611
+ def set_bit_depth_24():
612
+ logger.info("Bit depth set to 24-bit")
613
+ return "24"
614
+
615
+ # Wrapper for generate_music with post-generation cleanup
616
+ def generate_music_wrapper(*args):
617
+ try:
618
+ result = generate_music(*args)
619
+ return result
620
+ finally:
621
+ clean_memory()
622
+
623
+ # Optimized generation function with chunk-based prompt variation
624
+ def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, bpm: int, drum_beat: str, synthesizer: str, rhythmic_steps: str, bass_style: str, guitar_style: str, target_volume: float, preset: str, max_steps: str, vram_status: str, bitrate: str, output_sample_rate: str, bit_depth: str):
625
+ global musicgen_model
626
+ if not instrumental_prompt.strip():
627
+ logger.warning("Empty instrumental prompt provided")
628
+ return None, "⚠️ Please enter a valid instrumental prompt!", vram_status
629
+ try:
630
+ logger.info("Starting music generation...")
631
+ start_time = time.time()
632
+ clean_memory()
633
+ try:
634
+ max_steps_int = int(max_steps)
635
+ except ValueError:
636
+ logger.error(f"Invalid max_steps value: {max_steps}")
637
+ return None, "❌ Invalid max_steps value; must be a number (1000, 1200, 1300, or 1500)", vram_status
638
+ try:
639
+ output_sample_rate_int = int(output_sample_rate)
640
+ except ValueError:
641
+ logger.error(f"Invalid output_sample_rate value: {output_sample_rate}")
642
+ return None, "❌ Invalid output sampling rate; must be a number (22050, 32000, 44100, or 48000)", vram_status
643
+ try:
644
+ bit_depth_int = int(bit_depth)
645
+ sample_width = 3 if bit_depth_int == 24 else 2
646
+ except ValueError:
647
+ logger.error(f"Invalid bit_depth value: {bit_depth}")
648
+ return None, "❌ Invalid bit depth; must be 16 or 24", vram_status
649
+ max_duration = min(max_steps_int / 50, 30)
650
+ total_duration = min(max(total_duration, 30), 120)
651
+ processing_sample_rate = 48000 # Updated to user-recommended value
652
+ channels = 2
653
+ audio_segments = []
654
+ overlap_duration = 0.2
655
+ remaining_duration = total_duration
656
+
657
+ if preset != "default":
658
+ preset_params = PRESETS.get(preset, PRESETS["default"])
659
+ cfg_scale = preset_params["cfg_scale"]
660
+ top_k = preset_params["top_k"]
661
+ top_p = preset_params["top_p"]
662
+ temperature = preset_params["temperature"]
663
+ logger.info(f"Applied preset {preset}: cfg_scale={cfg_scale}, top_k={top_k}, top_p={top_p}, temperature={temperature}")
664
+
665
+ if not check_disk_space():
666
+ logger.error("Insufficient disk space")
667
+ return None, "⚠️ Insufficient disk space. Free up at least 1 GB.", vram_status
668
+
669
+ seed = random.randint(0, 10000)
670
+ logger.info(f"Generating audio for {total_duration}s with seed={seed}, max_steps={max_steps_int}, output_sample_rate={output_sample_rate_int} Hz, bit_depth={bit_depth_int}-bit")
671
+ vram_status = f"Initial VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
672
+
673
+ chunk_num = 0
674
+ while remaining_duration > 0:
675
+ current_duration = min(max_duration, remaining_duration)
676
+ generation_duration = current_duration
677
+ chunk_num += 1
678
+ logger.info(f"Generating chunk {chunk_num} ({current_duration}s, VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB)")
679
+
680
+ # Generate chunk-specific prompt for Red Hot Chili Peppers
681
+ if "Red Hot Chili Peppers" in instrumental_prompt:
682
+ chunk_prompt = set_red_hot_chili_peppers_prompt(bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, chunk_num)
683
+ else:
684
+ # For other prompts, use the base prompt without variation (as a fallback)
685
+ chunk_prompt = instrumental_prompt
686
+
687
+ musicgen_model.set_generation_params(
688
+ duration=generation_duration,
689
+ use_sampling=True,
690
+ top_k=top_k,
691
+ top_p=top_p,
692
+ temperature=temperature,
693
+ cfg_coef=cfg_scale
694
+ )
695
+
696
+ try:
697
+ with torch.no_grad():
698
+ with autocast(dtype=torch.float16):
699
+ torch.manual_seed(seed)
700
+ np.random.seed(seed)
701
+ torch.cuda.manual_seed_all(seed)
702
+ clean_memory()
703
+ if not audio_segments:
704
+ logger.debug("Generating first chunk")
705
+ audio_segment = musicgen_model.generate([chunk_prompt], progress=True)[0].cpu()
706
+ else:
707
+ logger.debug("Generating continuation chunk")
708
+ prev_segment = audio_segments[-1]
709
+ prev_segment = apply_noise_gate(prev_segment, threshold_db=-80, sample_rate=processing_sample_rate)
710
+ prev_segment = balance_stereo(prev_segment, noise_threshold=-40, sample_rate=processing_sample_rate)
711
+ temp_wav_path = f"temp_prev_{int(time.time()*1000)}.wav"
712
+ try:
713
+ logger.debug(f"Exporting previous segment to {temp_wav_path}")
714
+ prev_segment.export(temp_wav_path, format="wav")
715
+ with open(temp_wav_path, "rb") as f:
716
+ mmapped_file = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
717
+ prev_audio, prev_sr = torchaudio.load(temp_wav_path)
718
+ mmapped_file.close()
719
+ if prev_sr != processing_sample_rate:
720
+ logger.debug(f"Resampling from {prev_sr} to {processing_sample_rate}")
721
+ prev_audio = torchaudio.functional.resample(prev_audio, prev_sr, processing_sample_rate, lowpass_filter_width=64)
722
+ if prev_audio.shape[0] != 2:
723
+ logger.debug(f"Converting to stereo: {prev_audio.shape[0]} channels detected")
724
+ prev_audio = prev_audio.repeat(2, 1)[:, :prev_audio.shape[1]]
725
+ prev_audio = prev_audio.to(device)
726
+ audio_segment = musicgen_model.generate_continuation(
727
+ prompt=prev_audio[:, -int(processing_sample_rate * overlap_duration):],
728
+ prompt_sample_rate=processing_sample_rate,
729
+ descriptions=[chunk_prompt],
730
+ progress=True
731
+ )[0].cpu()
732
+ del prev_audio
733
+ finally:
734
+ try:
735
+ os.remove(temp_wav_path)
736
+ logger.debug(f"Deleted temporary file {temp_wav_path}")
737
+ except OSError:
738
+ logger.warning(f"Failed to delete temporary file {temp_wav_path}")
739
+ clean_memory()
740
+ except Exception as e:
741
+ logger.error(f"Error in chunk {chunk_num} generation: {e}")
742
+ logger.error(traceback.format_exc())
743
+ return None, f"❌ Failed to generate chunk {chunk_num}: {e}", vram_status
744
+
745
+ logger.debug(f"Generated audio segment shape: {audio_segment.shape}, dtype: {audio_segment.dtype}")
746
+ try:
747
+ # Ensure the model's output is resampled to processing_sample_rate
748
+ if audio_segment.shape[0] != 2:
749
+ logger.debug(f"Converting to stereo: {audio_segment.shape[0]} channels detected")
750
+ audio_segment = audio_segment.repeat(2, 1)[:, :audio_segment.shape[1]]
751
+ # Convert to float32 before resampling to avoid "slow_conv2d_cpu" error
752
+ audio_segment = audio_segment.to(dtype=torch.float32)
753
+ audio_segment = torchaudio.functional.resample(audio_segment, 32000, processing_sample_rate, lowpass_filter_width=64)
754
+ audio_np = audio_segment.numpy()
755
+ if audio_np.ndim == 1:
756
+ logger.debug("Converting mono to stereo on CPU")
757
+ audio_np = np.stack([audio_np, audio_np], axis=0)
758
+ if audio_np.shape[0] != 2:
759
+ logger.error(f"Expected stereo audio with shape (2, samples), got shape {audio_np.shape}")
760
+ return None, f"❌ Invalid audio shape for chunk {chunk_num}: {audio_np.shape}", vram_status
761
+ audio_segment = torch.from_numpy(audio_np).to(dtype=torch.float16)
762
+ logger.debug(f"Converted audio segment to float16, shape: {audio_segment.shape}")
763
+ except Exception as e:
764
+ logger.error(f"Failed to process audio segment for chunk {chunk_num}: {e}")
765
+ logger.error(traceback.format_exc())
766
+ return None, f"❌ Failed to process audio for chunk {chunk_num}: {e}", vram_status
767
+
768
+ temp_wav_path = f"temp_audio_{int(time.time()*1000)}.wav"
769
+ logger.debug(f"Saving audio segment to {temp_wav_path}, VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
770
+ try:
771
+ audio_segment_save = audio_segment.to(dtype=torch.float32)
772
+ torchaudio.save(temp_wav_path, audio_segment_save, processing_sample_rate, bits_per_sample=bit_depth_int)
773
+ del audio_segment_save
774
+ except Exception as e:
775
+ logger.error(f"Failed to save audio segment for chunk {chunk_num}: {e}")
776
+ logger.error(traceback.format_exc())
777
+ logger.warning(f"Skipping chunk {chunk_num} due to save error")
778
+ del audio_segment
779
+ clean_memory()
780
+ continue
781
+
782
+ clean_memory()
783
+ try:
784
+ with open(temp_wav_path, "rb") as f:
785
+ mmapped_file = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
786
+ segment = AudioSegment.from_wav(temp_wav_path)
787
+ mmapped_file.close()
788
+ except Exception as e:
789
+ logger.error(f"Failed to load WAV file for chunk {chunk_num}: {e}")
790
+ logger.error(traceback.format_exc())
791
+ logger.warning(f"Skipping chunk {chunk_num} due to WAV load error")
792
+ del audio_segment
793
+ clean_memory()
794
+ continue
795
+ finally:
796
+ try:
797
+ os.remove(temp_wav_path)
798
+ logger.debug(f"Deleted temporary file {temp_wav_path}")
799
+ except OSError:
800
+ logger.warning(f"Failed to delete temporary file {temp_wav_path}")
801
+
802
+ try:
803
+ segment = ensure_stereo(segment, processing_sample_rate, sample_width)
804
+ segment = segment - 15
805
+ if segment.frame_rate != processing_sample_rate:
806
+ logger.debug(f"Setting segment sample rate to {processing_sample_rate}")
807
+ segment = segment.set_frame_rate(processing_sample_rate)
808
+ # Apply noise gate immediately after loading to catch high-pitched tones early
809
+ segment = apply_noise_gate(segment, threshold_db=-80, sample_rate=processing_sample_rate)
810
+ segment = balance_stereo(segment, noise_threshold=-40, sample_rate=processing_sample_rate)
811
+ segment = rms_normalize(segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
812
+ segment = apply_eq(segment, sample_rate=processing_sample_rate)
813
+ audio_segments.append(segment)
814
+ except Exception as e:
815
+ logger.error(f"Failed to process audio segment for chunk {chunk_num}: {e}")
816
+ logger.error(traceback.format_exc())
817
+ logger.warning(f"Skipping chunk {chunk_num} due to processing error")
818
+ del audio_segment
819
+ clean_memory()
820
+ continue
821
+
822
+ del audio_segment
823
+ del audio_np
824
+ clean_memory()
825
+ vram_status = f"VRAM after chunk {chunk_num}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
826
+ time.sleep(0.1)
827
+ remaining_duration -= current_duration
828
+
829
+ if not audio_segments:
830
+ logger.error("No audio segments generated")
831
+ return None, "❌ No audio segments generated due to errors", vram_status
832
+
833
+ logger.info("Combining audio chunks...")
834
+ try:
835
+ final_segment = audio_segments[0][:min(max_duration, total_duration) * 1000]
836
+ final_segment = ensure_stereo(final_segment, processing_sample_rate, sample_width)
837
+ overlap_ms = int(overlap_duration * 1000)
838
+
839
+ for i in range(1, len(audio_segments)):
840
+ current_segment = audio_segments[i]
841
+ current_segment = current_segment[:min(max_duration, total_duration - (i * max_duration)) * 1000]
842
+ current_segment = ensure_stereo(current_segment, processing_sample_rate, sample_width)
843
+
844
+ if overlap_ms > 0 and len(current_segment) > overlap_ms:
845
+ logger.debug(f"Applying crossfade between chunks {i} and {i+1}")
846
+ prev_overlap = final_segment[-overlap_ms:]
847
+ curr_overlap = current_segment[:overlap_ms]
848
+ prev_wav_path = f"temp_prev_overlap_{int(time.time()*1000)}.wav"
849
+ curr_wav_path = f"temp_curr_overlap_{int(time.time()*1000)}.wav"
850
+ try:
851
+ prev_overlap.export(prev_wav_path, format="wav")
852
+ curr_overlap.export(curr_wav_path, format="wav")
853
+ clean_memory()
854
+ prev_audio, _ = torchaudio.load(prev_wav_path)
855
+ curr_audio, _ = torchaudio.load(curr_wav_path)
856
+ num_samples = min(prev_audio.shape[1], curr_audio.shape[1])
857
+ num_samples = num_samples - (num_samples % 2)
858
+ if num_samples <= 0:
859
+ logger.warning(f"Skipping crossfade for chunk {i+1} due to insufficient samples")
860
+ final_segment += current_segment
861
+ continue
862
+ blended_samples = torch.zeros(2, num_samples, dtype=torch.float32)
863
+ prev_samples = prev_audio[:, :num_samples]
864
+ curr_samples = curr_audio[:, :num_samples]
865
+ hann_window = torch.hann_window(num_samples, periodic=False)
866
+ fade_out = hann_window.flip(0)
867
+ fade_in = hann_window
868
+ blended_samples = (prev_samples * fade_out + curr_samples * fade_in)
869
+ blended_samples = (blended_samples * (2**23 if sample_width == 3 else 32767)).to(torch.int32 if sample_width == 3 else torch.int16)
870
+ temp_crossfade_path = f"temp_crossfade_{int(time.time()*1000)}.wav"
871
+ torchaudio.save(temp_crossfade_path, blended_samples, processing_sample_rate, bits_per_sample=bit_depth_int)
872
+ blended_segment = AudioSegment.from_wav(temp_crossfade_path)
873
+ blended_segment = ensure_stereo(blended_segment, processing_sample_rate, sample_width)
874
+ blended_segment = rms_normalize(blended_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
875
+ final_segment = final_segment[:-overlap_ms] + blended_segment + current_segment[overlap_ms:]
876
+ finally:
877
+ for temp_path in [prev_wav_path, curr_wav_path, temp_crossfade_path]:
878
+ try:
879
+ if os.path.exists(temp_path):
880
+ os.remove(temp_path)
881
+ logger.debug(f"Deleted temporary file {temp_path}")
882
+ except OSError:
883
+ logger.warning(f"Failed to delete temporary file {temp_path}")
884
+ else:
885
+ logger.debug(f"Concatenating chunk {i+1} without crossfade")
886
+ final_segment += current_segment
887
+
888
+ final_segment = final_segment[:total_duration * 1000]
889
+ logger.info("Post-processing final track...")
890
+ final_segment = apply_noise_gate(final_segment, threshold_db=-80, sample_rate=processing_sample_rate)
891
+ final_segment = balance_stereo(final_segment, noise_threshold=-40, sample_rate=processing_sample_rate)
892
+ final_segment = rms_normalize(final_segment, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=processing_sample_rate)
893
+ final_segment = apply_eq(final_segment, sample_rate=processing_sample_rate)
894
+ final_segment = apply_fade(final_segment)
895
+ final_segment = final_segment - 10
896
+ final_segment = final_segment.set_frame_rate(output_sample_rate_int)
897
+
898
+ mp3_path = f"output_adjusted_volume_{int(time.time())}.mp3"
899
+ logger.info("⚠️ WARNING: Audio is set to safe levels (~ -23 dBFS RMS, -3 dBFS peak). Start playback at LOW volume (10-20%) and adjust gradually.")
900
+ logger.info("VERIFY: Open the file in Audacity to check for high-pitched tones and quality. RMS should be ~ -23 dBFS, peaks ≀ -3 dBFS. Report any issues.")
901
+ try:
902
+ clean_memory()
903
+ logger.debug(f"Exporting final audio to {mp3_path} with bitrate {bitrate}, sample rate {output_sample_rate_int} Hz, bit depth {bit_depth_int}-bit")
904
+ final_segment.export(
905
+ mp3_path,
906
+ format="mp3",
907
+ bitrate=bitrate,
908
+ tags={"title": "GhostAI Instrumental", "artist": "GhostAI"}
909
+ )
910
+ logger.info(f"Final audio saved to {mp3_path}")
911
+ except Exception as e:
912
+ logger.error(f"Error exporting MP3 with bitrate {bitrate}: {e}")
913
+ logger.error(traceback.format_exc())
914
+ fallback_path = f"fallback_output_{int(time.time())}.mp3"
915
+ try:
916
+ final_segment.export(fallback_path, format="mp3", bitrate="128k")
917
+ logger.info(f"Final audio saved to fallback: {fallback_path} with 128 kbps")
918
+ mp3_path = fallback_path
919
+ except Exception as fallback_e:
920
+ logger.error(f"Failed to save fallback MP3: {fallback_e}")
921
+ return None, f"❌ Failed to export audio: {fallback_e}", vram_status
922
+
923
+ vram_status = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
924
+ logger.info(f"Generation completed in {time.time() - start_time:.2f} seconds")
925
+ return mp3_path, "βœ… Done! Generated track with adjusted volume levels. Check for quality in Audacity.", vram_status
926
+ except Exception as e:
927
+ logger.error(f"Failed to combine audio chunks: {e}")
928
+ logger.error(traceback.format_exc())
929
+ return None, f"❌ Failed to combine audio: {e}", vram_status
930
+ except Exception as e:
931
+ logger.error(f"Generation failed: {e}")
932
+ logger.error(traceback.format_exc())
933
+ return None, f"❌ Generation failed: {e}", vram_status
934
+ finally:
935
+ clean_memory()
936
+
937
+ # Clear inputs function
938
+ def clear_inputs():
939
+ logger.info("Clearing input fields")
940
+ return "", 5.8, 18, 0.88, 0.15, 30, 120, "none", "none", "none", "none", "none", -23.0, "default", 1300, "128k", "44100", "16"
941
+
942
+ # Custom CSS with high-contrast colors and green border on active selection
943
+ css = """
944
+ body {
945
+ background: #121212;
946
+ color: #E6E6E6;
947
+ font-family: 'Arial', sans-serif;
948
+ }
949
+ .header-container {
950
+ text-align: center;
951
+ padding: 15px 20px;
952
+ background: #1E1E1E;
953
+ border-bottom: 2px solid #00C853;
954
+ }
955
+ #ghost-logo {
956
+ font-size: 48px;
957
+ color: #00C853;
958
+ }
959
+ h1 {
960
+ color: #FFD600;
961
+ font-size: 28px;
962
+ font-weight: bold;
963
+ }
964
+ h3 {
965
+ color: #FFD600;
966
+ font-size: 20px;
967
+ font-weight: bold;
968
+ }
969
+ p {
970
+ color: #B0BEC5;
971
+ font-size: 14px;
972
+ }
973
+ .input-container, .settings-container, .output-container, .logs-container {
974
+ max-width: 1200px;
975
+ margin: 20px auto;
976
+ padding: 20px;
977
+ background: #212121;
978
+ border: 1px solid #424242;
979
+ border-radius: 8px;
980
+ }
981
+ .textbox {
982
+ background: #2C2C2C;
983
+ border: 1px solid #B0BEC5;
984
+ color: #E6E6E6;
985
+ font-size: 16px;
986
+ }
987
+ .genre-buttons, .bitrate-buttons, .sample-rate-buttons, .bit-depth-buttons {
988
+ display: flex;
989
+ justify-content: center;
990
+ flex-wrap: wrap;
991
+ gap: 10px;
992
+ }
993
+ .genre-btn, .bitrate-btn, .sample-rate-btn, .bit-depth-btn, button {
994
+ background: #0288D1;
995
+ border: 2px solid transparent;
996
+ color: #FFFFFF;
997
+ padding: 10px 20px;
998
+ border-radius: 5px;
999
+ font-size: 16px;
1000
+ transition: all 0.3s ease;
1001
+ }
1002
+ button:hover {
1003
+ background: #03A9F4;
1004
+ cursor: pointer;
1005
+ }
1006
+ button:active, .genre-btn.active, .bitrate-btn.active, .sample-rate-btn.active, .bit-depth-btn.active {
1007
+ border: 2px solid #00C853 !important;
1008
+ background: #01579B;
1009
+ color: #FFFFFF;
1010
+ }
1011
+ .gradio-container {
1012
+ padding: 20px;
1013
+ }
1014
+ .group-container {
1015
+ margin-bottom: 20px;
1016
+ padding: 15px;
1017
+ border: 1px solid #424242;
1018
+ border-radius: 8px;
1019
+ }
1020
+ .slider-label, .dropdown-label {
1021
+ color: #FFD600;
1022
+ font-size: 16px;
1023
+ font-weight: bold;
1024
+ }
1025
+ .slider, .dropdown {
1026
+ background: #2C2C2C;
1027
+ color: #E6E6E6;
1028
+ }
1029
+ .output-container label, .logs-container label {
1030
+ color: #FFD600;
1031
+ font-size: 16px;
1032
+ font-weight: bold;
1033
+ }
1034
+ """
1035
+
1036
+ # Build Gradio interface with updated visuals and default preset
1037
+ logger.info("Building Gradio interface...")
1038
+ with gr.Blocks(css=css) as demo:
1039
+ gr.Markdown("""
1040
+ <div class="header-container">
1041
+ <div id="ghost-logo">πŸ‘»</div>
1042
+ <h1>GhostAI Music Generator 🎹</h1>
1043
+ <p>Create Instrumental Tracks with Ease</p>
1044
+ </div>
1045
+ """)
1046
+
1047
+ with gr.Column(elem_classes="input-container"):
1048
+ gr.Markdown("### 🎸 Prompt Settings")
1049
+ instrumental_prompt = gr.Textbox(
1050
+ label="Instrumental Prompt ✍️",
1051
+ placeholder="Click a genre button or type your own instrumental prompt",
1052
+ lines=4,
1053
+ elem_classes="textbox"
1054
+ )
1055
+ with gr.Row(elem_classes="genre-buttons"):
1056
+ rhcp_btn = gr.Button("Red Hot Chili Peppers 🌢️", elem_classes="genre-btn")
1057
+ nirvana_btn = gr.Button("Nirvana Grunge 🎸", elem_classes="genre-btn")
1058
+ pearl_jam_btn = gr.Button("Pearl Jam Grunge πŸ¦ͺ", elem_classes="genre-btn")
1059
+ soundgarden_btn = gr.Button("Soundgarden Grunge πŸŒ‘", elem_classes="genre-btn")
1060
+ foo_fighters_btn = gr.Button("Foo Fighters 🀘", elem_classes="genre-btn")
1061
+ smashing_pumpkins_btn = gr.Button("Smashing Pumpkins πŸŽƒ", elem_classes="genre-btn")
1062
+ radiohead_btn = gr.Button("Radiohead 🧠", elem_classes="genre-btn")
1063
+ classic_rock_btn = gr.Button("Metallica Heavy Metal 🎸", elem_classes="genre-btn")
1064
+ alternative_rock_btn = gr.Button("Alternative Rock 🎡", elem_classes="genre-btn")
1065
+ post_punk_btn = gr.Button("Post-Punk πŸ–€", elem_classes="genre-btn")
1066
+ indie_rock_btn = gr.Button("Indie Rock 🎀", elem_classes="genre-btn")
1067
+ funk_rock_btn = gr.Button("Funk Rock πŸ•Ί", elem_classes="genre-btn")
1068
+ detroit_techno_btn = gr.Button("Detroit Techno πŸŽ›οΈ", elem_classes="genre-btn")
1069
+ deep_house_btn = gr.Button("Deep House 🏠", elem_classes="genre-btn")
1070
+
1071
+ with gr.Column(elem_classes="settings-container"):
1072
+ gr.Markdown("### βš™οΈ API Settings")
1073
+ with gr.Group(elem_classes="group-container"):
1074
+ cfg_scale = gr.Slider(
1075
+ label="CFG Scale 🎯",
1076
+ minimum=1.0,
1077
+ maximum=10.0,
1078
+ value=5.8,
1079
+ step=0.1,
1080
+ info="Controls how closely the music follows the prompt."
1081
+ )
1082
+ top_k = gr.Slider(
1083
+ label="Top-K Sampling πŸ”’",
1084
+ minimum=10,
1085
+ maximum=500,
1086
+ value=18,
1087
+ step=10,
1088
+ info="Limits sampling to the top k most likely tokens."
1089
+ )
1090
+ top_p = gr.Slider(
1091
+ label="Top-P Sampling 🎰",
1092
+ minimum=0.0,
1093
+ maximum=1.0,
1094
+ value=0.88,
1095
+ step=0.05,
1096
+ info="Keeps tokens with cumulative probability above p."
1097
+ )
1098
+ temperature = gr.Slider(
1099
+ label="Temperature πŸ”₯",
1100
+ minimum=0.1,
1101
+ maximum=2.0,
1102
+ value=0.15,
1103
+ step=0.1,
1104
+ info="Controls randomness; lower values reduce noise."
1105
+ )
1106
+ total_duration = gr.Dropdown(
1107
+ label="Song Length ⏳ (seconds)",
1108
+ choices=[30, 60, 90, 120],
1109
+ value=30,
1110
+ info="Select the total duration of the track."
1111
+ )
1112
+ bpm = gr.Slider(
1113
+ label="Tempo 🎡 (BPM)",
1114
+ minimum=60,
1115
+ maximum=180,
1116
+ value=120,
1117
+ step=1,
1118
+ info="Beats per minute to set the track's tempo."
1119
+ )
1120
+ drum_beat = gr.Dropdown(
1121
+ label="Drum Beat πŸ₯",
1122
+ choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing"],
1123
+ value="none",
1124
+ info="Select a drum beat style to influence the rhythm."
1125
+ )
1126
+ synthesizer = gr.Dropdown(
1127
+ label="Synthesizer 🎹",
1128
+ choices=["none", "analog synth", "digital pad", "arpeggiated synth"],
1129
+ value="none",
1130
+ info="Select a synthesizer style for electronic accents."
1131
+ )
1132
+ rhythmic_steps = gr.Dropdown(
1133
+ label="Rhythmic Steps πŸ‘£",
1134
+ choices=["none", "syncopated steps", "steady steps", "complex steps"],
1135
+ value="none",
1136
+ info="Select a rhythmic step style to enhance the beat."
1137
+ )
1138
+ bass_style = gr.Dropdown(
1139
+ label="Bass Style 🎸",
1140
+ choices=["none", "slap bass", "deep bass", "melodic bass"],
1141
+ value="none",
1142
+ info="Select a bass style to shape the low end."
1143
+ )
1144
+ guitar_style = gr.Dropdown(
1145
+ label="Guitar Style 🎸",
1146
+ choices=["none", "distorted", "clean", "jangle"],
1147
+ value="none",
1148
+ info="Select a guitar style to define the riffs."
1149
+ )
1150
+ target_volume = gr.Slider(
1151
+ label="Target Volume 🎚️ (dBFS RMS)",
1152
+ minimum=-30.0,
1153
+ maximum=-20.0,
1154
+ value=-23.0,
1155
+ step=1.0,
1156
+ info="Adjust output loudness (-23 dBFS is standard, -20 dBFS is louder, -30 dBFS is quieter)."
1157
+ )
1158
+ preset = gr.Dropdown(
1159
+ label="Preset Configuration πŸŽ›οΈ",
1160
+ choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"],
1161
+ value="default",
1162
+ info="Select a preset optimized for specific genres."
1163
+ )
1164
+ max_steps = gr.Dropdown(
1165
+ label="Max Steps per Chunk πŸ“",
1166
+ choices=[1000, 1200, 1300, 1500],
1167
+ value=1300,
1168
+ info="Number of generation steps per chunk (1300=~26s, extended to 30s)."
1169
+ )
1170
+ bitrate_state = gr.State(value="128k")
1171
+ sample_rate_state = gr.State(value="44100")
1172
+ bit_depth_state = gr.State(value="16")
1173
+ with gr.Row(elem_classes="bitrate-buttons"):
1174
+ bitrate_128_btn = gr.Button("Set Bitrate to 128 kbps", elem_classes="bitrate-btn")
1175
+ bitrate_192_btn = gr.Button("Set Bitrate to 192 kbps", elem_classes="bitrate-btn")
1176
+ bitrate_320_btn = gr.Button("Set Bitrate to 320 kbps", elem_classes="bitrate-btn")
1177
+ with gr.Row(elem_classes="sample-rate-buttons"):
1178
+ sample_rate_22050_btn = gr.Button("Set Sampling Rate to 22.05 kHz", elem_classes="sample-rate-btn")
1179
+ sample_rate_44100_btn = gr.Button("Set Sampling Rate to 44.1 kHz", elem_classes="sample-rate-btn")
1180
+ sample_rate_48000_btn = gr.Button("Set Sampling Rate to 48 kHz", elem_classes="sample-rate-btn")
1181
+ with gr.Row(elem_classes="bit-depth-buttons"):
1182
+ bit_depth_16_btn = gr.Button("Set Bit Depth to 16-bit", elem_classes="bit-depth-btn")
1183
+ bit_depth_24_btn = gr.Button("Set Bit Depth to 24-bit", elem_classes="bit-depth-btn")
1184
+
1185
+ with gr.Row(elem_classes="action-buttons"):
1186
+ gen_btn = gr.Button("Generate Music πŸš€")
1187
+ clr_btn = gr.Button("Clear Inputs 🧹")
1188
+
1189
+ with gr.Column(elem_classes="output-container"):
1190
+ gr.Markdown("### 🎧 Output")
1191
+ out_audio = gr.Audio(label="Generated Instrumental Track 🎡", type="filepath")
1192
+ status = gr.Textbox(label="Status πŸ“’", interactive=False)
1193
+ vram_status = gr.Textbox(label="VRAM Usage πŸ“Š", interactive=False, value="")
1194
+
1195
+ with gr.Column(elem_classes="logs-container"):
1196
+ gr.Markdown("### πŸ“œ Logs")
1197
+ log_output = gr.Textbox(label="Last Log File Contents", lines=20, interactive=False)
1198
+ log_btn = gr.Button("View Last Log πŸ“‹")
1199
+
1200
+ # Add JavaScript to handle button selection visuals
1201
+ def update_button_styles(selected_button):
1202
+ buttons = [
1203
+ "rhcp_btn", "nirvana_btn", "pearl_jam_btn", "soundgarden_btn", "foo_fighters_btn",
1204
+ "smashing_pumpkins_btn", "radiohead_btn", "classic_rock_btn", "alternative_rock_btn",
1205
+ "post_punk_btn", "indie_rock_btn", "funk_rock_btn", "detroit_techno_btn", "deep_house_btn",
1206
+ "bitrate_128_btn", "bitrate_192_btn", "bitrate_320_btn",
1207
+ "sample_rate_22050_btn", "sample_rate_44100_btn", "sample_rate_48000_btn",
1208
+ "bit_depth_16_btn", "bit_depth_24_btn"
1209
+ ]
1210
+ script = """
1211
+ <script>
1212
+ document.querySelectorAll('.genre-btn, .bitrate-btn, .sample-rate-btn, .bit-depth-btn').forEach(btn => {
1213
+ btn.classList.remove('active');
1214
+ });
1215
+ document.querySelector('#""" + selected_button + """').classList.add('active');
1216
+ </script>
1217
+ """
1218
+ return script
1219
+
1220
+ rhcp_btn.click(set_red_hot_chili_peppers_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, gr.State(value=1)], outputs=instrumental_prompt, _js=update_button_styles("rhcp_btn"))
1221
+ nirvana_btn.click(set_nirvana_grunge_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("nirvana_btn"))
1222
+ pearl_jam_btn.click(set_pearl_jam_grunge_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("pearl_jam_btn"))
1223
+ soundgarden_btn.click(set_soundgarden_grunge_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("soundgarden_btn"))
1224
+ foo_fighters_btn.click(set_foo_fighters_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("foo_fighters_btn"))
1225
+ smashing_pumpkins_btn.click(set_smashing_pumpkins_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("smashing_pumpkins_btn"))
1226
+ radiohead_btn.click(set_radiohead_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("radiohead_btn"))
1227
+ classic_rock_btn.click(set_classic_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("classic_rock_btn"))
1228
+ alternative_rock_btn.click(set_alternative_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("alternative_rock_btn"))
1229
+ post_punk_btn.click(set_post_punk_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("post_punk_btn"))
1230
+ indie_rock_btn.click(set_indie_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("indie_rock_btn"))
1231
+ funk_rock_btn.click(set_funk_rock_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("funk_rock_btn"))
1232
+ detroit_techno_btn.click(set_detroit_techno_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("detroit_techno_btn"))
1233
+ deep_house_btn.click(set_deep_house_prompt, inputs=[bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style], outputs=instrumental_prompt, _js=update_button_styles("deep_house_btn"))
1234
+ bitrate_128_btn.click(set_bitrate_128, inputs=None, outputs=bitrate_state, _js=update_button_styles("bitrate_128_btn"))
1235
+ bitrate_192_btn.click(set_bitrate_192, inputs=None, outputs=bitrate_state, _js=update_button_styles("bitrate_192_btn"))
1236
+ bitrate_320_btn.click(set_bitrate_320, inputs=None, outputs=bitrate_state, _js=update_button_styles("bitrate_320_btn"))
1237
+ sample_rate_22050_btn.click(set_sample_rate_22050, inputs=None, outputs=sample_rate_state, _js=update_button_styles("sample_rate_22050_btn"))
1238
+ sample_rate_44100_btn.click(set_sample_rate_44100, inputs=None, outputs=sample_rate_state, _js=update_button_styles("sample_rate_44100_btn"))
1239
+ sample_rate_48000_btn.click(set_sample_rate_48000, inputs=None, outputs=sample_rate_state, _js=update_button_styles("sample_rate_48000_btn"))
1240
+ bit_depth_16_btn.click(set_bit_depth_16, inputs=None, outputs=bit_depth_state, _js=update_button_styles("bit_depth_16_btn"))
1241
+ bit_depth_24_btn.click(set_bit_depth_24, inputs=None, outputs=bit_depth_state, _js=update_button_styles("bit_depth_24_btn"))
1242
+ gen_btn.click(
1243
+ generate_music_wrapper,
1244
+ inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, vram_status, bitrate_state, sample_rate_state, bit_depth_state],
1245
+ outputs=[out_audio, status, vram_status]
1246
+ )
1247
+ clr_btn.click(
1248
+ clear_inputs,
1249
+ inputs=None,
1250
+ outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume, preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state]
1251
+ )
1252
+ log_btn.click(
1253
+ get_latest_log,
1254
+ inputs=None,
1255
+ outputs=log_output
1256
+ )
1257
+
1258
+ # Launch locally without OpenAPI/docs
1259
+ logger.info("Launching Gradio UI at http://localhost:9999...")
1260
+ try:
1261
+ app = demo.launch(
1262
+ server_name="0.0.0.0",
1263
+ server_port=9999,
1264
+ share=False,
1265
+ inbrowser=False,
1266
+ show_error=True
1267
+ )
1268
+ except Exception as e:
1269
+ logger.error(f"Failed to launch Gradio UI: {e}")
1270
+ logger.error(traceback.format_exc())
1271
+ sys.exit(1)