# Spaces:
# Running
# Running
import os
import re
import tempfile

import gradio as gr
import spaces
import torch
from dotenv import load_dotenv
from pydub import AudioSegment
from scipy.io.wavfile import write
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    MusicgenForConditionalGeneration,
    pipeline,
)
from TTS.api import TTS
# -------------------------------
# Configuration
# -------------------------------
# Run load_dotenv() before reading the token so a value supplied through a
# .env file is visible in the process environment (presumed python-dotenv
# behaviour; the token may equally come from the real environment).
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Display name -> model identifier handed to the text-generation loader.
_LLAMA_MODELS = {
    "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
    "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "Phi-3-mini": "microsoft/Phi-3-mini-4k-instruct",
}

# Display name -> model name handed to ``TTS(...)``.
_TTS_MODELS = {
    "Standard English": "tts_models/en/ljspeech/tacotron2-DDC",
    "High Quality": "tts_models/en/ljspeech/vits",
    "Fast Inference": "tts_models/en/sam/tacotron-DDC",
}

# Single lookup table used by the UI dropdowns.
MODEL_CONFIG = {
    "llama_models": _LLAMA_MODELS,
    "tts_models": _TTS_MODELS,
}
# -------------------------------
# Model Manager
# -------------------------------
class ModelManager:
    """Lazily build and cache the text, music and speech models.

    Each ``get_*`` method creates its model the first time an identifier is
    requested and stores it in a per-kind dictionary; later calls with the
    same identifier return the stored object without reloading.
    """

    def __init__(self):
        # One cache per model family, keyed by the identifier passed in.
        self.llama_pipelines = {}
        self.musicgen_models = {}
        self.tts_models = {}

    def get_llama_pipeline(self, model_id, token):
        """Return a cached ``text-generation`` pipeline for *model_id*.

        Args:
            model_id: Model identifier passed to ``from_pretrained``.
            token: Access token forwarded to the tokenizer and model loaders.

        Returns:
            The pipeline object stored for *model_id*.
        """
        if model_id not in self.llama_pipelines:
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
            load_kwargs = {
                "token": token,
                "torch_dtype": torch.float16,
                "device_map": "auto",
            }
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    attn_implementation="flash_attention_2",
                    **load_kwargs,
                )
            except (ImportError, ValueError):
                # FlashAttention-2 is an optional accelerator: when it cannot
                # be used, load with the library's default attention instead
                # of failing the whole request.
                # NOTE(review): assumes transformers reports an unusable
                # flash-attn backend via ImportError/ValueError - confirm for
                # the pinned version.
                model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
            self.llama_pipelines[model_id] = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device_map="auto",
            )
        return self.llama_pipelines[model_id]

    def get_musicgen_model(self, model_key="facebook/musicgen-large"):
        """Return a cached ``(model, processor)`` pair for *model_key*.

        The model is moved to CUDA when ``torch.cuda.is_available()`` is true
        and to the CPU otherwise.
        """
        if model_key not in self.musicgen_models:
            model = MusicgenForConditionalGeneration.from_pretrained(model_key)
            processor = AutoProcessor.from_pretrained(model_key)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)
            self.musicgen_models[model_key] = (model, processor)
        return self.musicgen_models[model_key]

    def get_tts_model(self, model_name):
        """Return a cached ``TTS`` instance constructed from *model_name*."""
        if model_name not in self.tts_models:
            self.tts_models[model_name] = TTS(model_name)
        return self.tts_models[model_name]


# Shared instance used by the generation functions in this module.
model_manager = ModelManager()
# -------------------------------
# Core Functions
# -------------------------------
def generate_script(user_prompt, model_id, duration, temperature=0.7, max_tokens=512):
    """Ask the selected chat model for a promo concept and split it up.

    Args:
        user_prompt: The user's description, sent as the ``user`` message.
        model_id: Identifier passed to ``model_manager.get_llama_pipeline``.
        duration: Promo length in seconds, interpolated into the system prompt.
        temperature: Sampling temperature forwarded to the pipeline.
        max_tokens: Forwarded to the pipeline as ``max_new_tokens``.

    Returns:
        The three-item result of ``parse_generated_content`` on the last
        message of the generated conversation. If anything raises, returns
        ``("Error: <message>", "", "")`` instead.
    """
    try:
        generator = model_manager.get_llama_pipeline(model_id, HF_TOKEN)

        instructions = f"""You are an AI audio production assistant. Create content for a {duration}-second promo:
1. Voice Script: [Clear, engaging narration]
2. Sound Design: [3-5 specific sound effects]
3. Music: [Genre, tempo, mood suggestions]
Keep sections concise and production-ready."""

        chat = [
            dict(role="system", content=instructions),
            dict(role="user", content=user_prompt),
        ]
        sampling_options = {
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "top_p": 0.95,
            "eos_token_id": generator.tokenizer.eos_token_id,
        }
        outputs = generator(chat, **sampling_options)

        # The last entry of the returned conversation is the one that is parsed.
        conversation = outputs[0]["generated_text"]
        reply_text = conversation[-1]["content"]
        return parse_generated_content(reply_text)
    except Exception as exc:
        # Best-effort for the UI: the message is returned as the first of the
        # three outputs.
        return "Error: " + str(exc), "", ""
# Section titles in the order they are returned.
_SECTION_TITLES = ("Voice Script", "Sound Design", "Music")

# A section header must open the line. Bullets, quote/heading markers, list
# numbering ("1.", "2)") and Markdown emphasis around the title are tolerated
# and discarded; whatever follows the colon is captured as the first content
# line of the section.
_SECTION_HEADER_RE = re.compile(
    r"^[\s#>*\-]*"
    r"(?:\d+[.)]\s*)?"
    r"[*_\s]*"
    r"(voice script|sound design|music)"
    r"[*_\s]*:"
    r"(?:\*\*|__|\*|_)?\s*"
    r"(.*)$",
    re.IGNORECASE,
)


def parse_generated_content(text):
    """Split generated text into voice script, sound design and music parts.

    A line that starts with ``Voice Script:``, ``Sound Design:`` or ``Music:``
    (case-insensitive, optionally preceded by list numbering or Markdown
    markers such as ``1.``, ``-``, ``###`` or ``**``) opens that section.
    Every following non-blank line belongs to it until the next header. Lines
    before the first header are ignored, and a title that merely appears in
    the middle of a sentence does not switch sections.

    Args:
        text: Raw model output; ``None`` is treated as empty.

    Returns:
        A tuple ``(voice_script, sound_design, music)`` of stripped strings,
        each possibly empty, with section lines joined by newlines.
    """
    canonical = {title.lower(): title for title in _SECTION_TITLES}
    collected = {title: [] for title in _SECTION_TITLES}
    current = None

    for raw_line in (text or "").splitlines():
        line = raw_line.strip()
        header = _SECTION_HEADER_RE.match(line)
        if header:
            current = canonical[header.group(1).lower()]
            # Keep only the text after the header so numbering and emphasis
            # markers do not leak into the section content.
            line = header.group(2).strip()
        if current and line:
            collected[current].append(line)

    return tuple("\n".join(collected[title]) for title in _SECTION_TITLES)
def generate_voice(script, tts_model, speed=1.0):
    """Synthesize *script* to a WAV file with the selected TTS model.

    Args:
        script: Text to speak. ``None``, empty or whitespace-only text is
            rejected with the message ``"Empty script"``.
        tts_model: Model name passed to ``model_manager.get_tts_model``.
        speed: Forwarded to ``tts_to_file`` as ``speed``.

    Returns:
        Path of the generated WAV file, or the string ``"Error: <message>"``
        if anything fails.
    """
    try:
        # Check for None explicitly so a missing script reports the same clear
        # message as an empty one instead of an attribute error.
        if not script or not script.strip():
            raise ValueError("Empty script")

        tts = model_manager.get_tts_model(tts_model)

        # Use a unique file per call: a fixed name in the temp directory would
        # let concurrent requests overwrite each other's audio.
        fd, output_path = tempfile.mkstemp(prefix="enhanced_voice_", suffix=".wav")
        os.close(fd)

        tts.tts_to_file(
            text=script,
            file_path=output_path,
            speed=speed,
        )
        return output_path
    except Exception as e:
        return f"Error: {str(e)}"
def generate_music(prompt, duration_sec=30, temperature=1.0, guidance_scale=3.0):
    """Generate music for a text prompt and write it to a WAV file.

    Args:
        prompt: Text description passed to the processor.
        duration_sec: Requested length; converted to ``int(duration_sec * 50)``
            new tokens for generation.
        temperature: Sampling temperature forwarded to ``generate``.
        guidance_scale: Forwarded to ``generate`` as ``guidance_scale``.

    Returns:
        Path of the generated WAV file, or the string ``"Error: <message>"``
        if anything fails.
    """
    try:
        model, processor = model_manager.get_musicgen_model()

        # Put the inputs wherever the loaded model actually lives rather than
        # re-deriving the device here.
        inputs = processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to(model.device)

        audio_values = model.generate(
            **inputs,
            max_new_tokens=int(duration_sec * 50),
            temperature=temperature,
            guidance_scale=guidance_scale,
            do_sample=True,
        )

        # Use a unique file per call: a fixed name in the temp directory would
        # let concurrent requests overwrite each other's audio.
        fd, output_path = tempfile.mkstemp(prefix="enhanced_music_", suffix=".wav")
        os.close(fd)

        # First item of the batch, first channel, written at 32 kHz.
        # NOTE(review): 32000 is assumed to match the checkpoint's output
        # rate - confirm if the default model changes.
        write(output_path, 32000, audio_values[0, 0].cpu().numpy())
        return output_path
    except Exception as e:
        return f"Error: {str(e)}"
def blend_audio(voice_path, music_path, ducking=True, duck_level=10, crossfade=500):
    """Mix a voice track over a music bed and export the result as WAV.

    The music is looped when shorter than the voice, trimmed to the voice
    length and faded out. With *ducking* enabled the music is lowered by
    *duck_level* dB and the voice is faded in before overlaying.

    Args:
        voice_path: Path to the voice WAV file.
        music_path: Path to the music WAV file.
        ducking: Whether to lower the music under the voice.
        duck_level: Gain reduction in dB applied to the music when ducking.
        crossfade: Fade length in milliseconds; ``0`` disables the fades.

    Returns:
        Path of the mixed WAV file, or the string ``"Error: <message>"`` if
        anything fails.
    """
    try:
        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)

        if len(music) == 0:
            # Looping below divides by the music length; report this clearly.
            raise ValueError("Music track is empty")

        # UI sliders can deliver floats and allow 0. Fades are given whole
        # milliseconds, and a zero-length fade is skipped instead of applied.
        # NOTE(review): pydub's fade appears to divide by the fade length, so
        # 0 ms would fail rather than be a no-op - confirm.
        fade_ms = max(0, int(crossfade))

        if len(music) < len(voice):
            loops = (len(voice) // len(music)) + 1
            music = music * loops
        music = music[:len(voice)]
        if fade_ms:
            music = music.fade_out(fade_ms)

        if ducking:
            lead_voice = voice.fade_in(fade_ms) if fade_ms else voice
            mixed = (music - duck_level).overlay(lead_voice)
        else:
            mixed = music.overlay(voice)

        # Use a unique file per call: a fixed name in the temp directory would
        # let concurrent requests overwrite each other's audio.
        fd, output_path = tempfile.mkstemp(prefix="enhanced_mix_", suffix=".wav")
        os.close(fd)

        mixed.export(output_path, format="wav")
        return output_path
    except Exception as e:
        return f"Error: {str(e)}"
# -------------------------------
# Gradio Interface
# -------------------------------
# Soft base theme with blue/teal hues, then overrides for the two *_dark
# variables (body text colour and primary background fill).
_soft_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="teal")
theme = _soft_theme.set(
    body_text_color_dark="#FFFFFF",
    background_fill_primary_dark="#1F1F1F",
)
# Four-tab workflow: concept text -> voiceover -> music -> final mix. Each
# tab's outputs are the next tab's inputs.
# NOTE(review): emoji in the labels were restored from mis-decoded text. The
# title/voiceover microphone, the "Final Mix" tab icon, the "Create Final Mix"
# button icon and the "Example Prompts" icon are best guesses - confirm.
with gr.Blocks(theme=theme, title="AI Audio Studio Pro") as demo:
    gr.Markdown("""
    # 🎙️ AI Audio Studio Pro
    *Next-generation audio production powered by AI*
    """)

    with gr.Tabs():
        with gr.Tab("🎯 Concept Development"):
            with gr.Row():
                with gr.Column(scale=2):
                    concept_input = gr.Textbox(
                        label="Your Concept",
                        placeholder="Describe your audio project...",
                        lines=3,
                        max_lines=6,
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            model_selector = gr.Dropdown(
                                choices=list(MODEL_CONFIG["llama_models"].values()),
                                label="AI Model",
                                value=MODEL_CONFIG["llama_models"]["Meta-Llama-3-8B"],
                            )
                            duration_slider = gr.Slider(15, 120, value=30, step=15, label="Duration (seconds)")
                        with gr.Row():
                            temp_slider = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Creativity")
                            token_slider = gr.Slider(128, 1024, value=512, step=128, label="Max Length")
                    generate_btn = gr.Button("✨ Generate Concept", variant="primary")
                with gr.Column(scale=1):
                    script_output = gr.Textbox(label="Voice Script", interactive=True)
                    sound_output = gr.Textbox(label="Sound Design", interactive=True)
                    music_output = gr.Textbox(label="Music Suggestions", interactive=True)
            generate_btn.click(
                generate_script,
                inputs=[concept_input, model_selector, duration_slider, temp_slider, token_slider],
                outputs=[script_output, sound_output, music_output],
            )

        with gr.Tab("🗣️ Voice Production"):
            with gr.Row():
                with gr.Column():
                    tts_model = gr.Dropdown(
                        choices=list(MODEL_CONFIG["tts_models"].values()),
                        label="Voice Model",
                        value=MODEL_CONFIG["tts_models"]["Standard English"],
                    )
                    speed_slider = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speaking Rate")
                    voice_btn = gr.Button("🎙️ Generate Voiceover", variant="primary")
                with gr.Column():
                    # type="filepath": this component is later used as an
                    # input to blend_audio, which opens its argument as a WAV
                    # path. The default type would hand it a
                    # (sample_rate, array) tuple instead.
                    voice_preview = gr.Audio(label="Preview", interactive=False, type="filepath")
            voice_btn.click(
                generate_voice,
                inputs=[script_output, tts_model, speed_slider],
                outputs=voice_preview,
            )

        with gr.Tab("🎶 Music Production"):
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Music Parameters", open=True):
                        music_duration = gr.Slider(10, 120, value=30, label="Duration (seconds)")
                        music_temp = gr.Slider(0.1, 2.0, value=1.0, label="Creativity")
                        guidance_scale = gr.Slider(1.0, 5.0, value=3.0, label="Focus")
                    music_btn = gr.Button("🎵 Generate Music", variant="primary")
                with gr.Column():
                    # type="filepath" for the same reason as voice_preview.
                    music_preview = gr.Audio(label="Preview", interactive=False, type="filepath")
            music_btn.click(
                generate_music,
                inputs=[music_output, music_duration, music_temp, guidance_scale],
                outputs=music_preview,
            )

        with gr.Tab("🔀 Final Mix"):
            with gr.Row():
                with gr.Column():
                    ducking_toggle = gr.Checkbox(value=True, label="Enable Voice Ducking")
                    duck_level = gr.Slider(0, 30, value=12, label="Ducking Strength (dB)")
                    crossfade_time = gr.Slider(0, 2000, value=500, label="Crossfade (ms)")
                    mix_btn = gr.Button("🚀 Create Final Mix", variant="primary")
                with gr.Column():
                    final_mix = gr.Audio(label="Master Output", interactive=False)
            mix_btn.click(
                blend_audio,
                inputs=[voice_preview, music_preview, ducking_toggle, duck_level, crossfade_time],
                outputs=final_mix,
            )

    with gr.Accordion("📚 Example Prompts", open=False):
        gr.Examples(
            examples=[
                ["A 30-second tech podcast intro with futuristic sounds"],
                ["A 15-second radio ad for a coffee shop with morning vibes"],
                ["A 60-second documentary trailer with epic orchestral music"],
            ],
            inputs=concept_input,
        )

    with gr.Row():
        gr.Markdown("### System Resources")
        # NOTE(review): nothing in the visible code writes to these two boxes,
        # so they render empty - wire them up or remove them.
        gpu_status = gr.Textbox(label="GPU Utilization", interactive=False)
        ram_status = gr.Textbox(label="RAM Usage", interactive=False)

    # Custom Footer
    gr.Markdown("""
    <hr>
    <p style="text-align: center; font-size: 0.9em;">
    Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
    </p>
    """)
    gr.HTML("""
    <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
    <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
    </a>
    """)
if __name__ == "__main__":
    # Listen on every interface at port 7860 when run as a script.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)