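"""AIPromoStudio: a Gradio app that drafts a promo concept with an LLM,
voices the script with Coqui TTS, scores it with MusicGen, and blends the
stems into a final mix."""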
import gradio as gr
import os
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
from pydub import AudioSegment
from dotenv import load_dotenv
import tempfile
import spaces
from TTS.api import TTS
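# Note on dependencies: pydub requires ffmpeg on the PATH for audio decoding
# and export, `TTS` is the Coqui TTS package, and `spaces` is the Hugging Face
# package that provides the @spaces.GPU decorator on ZeroGPU Spaces.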
# -------------------------------
# Configuration
# -------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_CONFIG = {
    "llama_models": {
        "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
        "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
        "Phi-3-mini": "microsoft/Phi-3-mini-4k-instruct",
    },
    "tts_models": {
        "Standard English": "tts_models/en/ljspeech/tacotron2-DDC",
        "High Quality": "tts_models/en/ljspeech/vits",
        "Fast Inference": "tts_models/en/sam/tacotron-DDC",
    },
}
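# Note: the Meta-Llama-3 checkpoint is gated on the Hugging Face Hub, so
# script generation with it requires an HF_TOKEN from an account that has
# accepted the model license; the other checkpoints are openly downloadable.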
# -------------------------------
# Model Manager
# -------------------------------
class ModelManager:
    """Lazily loads and caches the LLM, MusicGen, and TTS models."""

    def __init__(self):
        self.llama_pipelines = {}
        self.musicgen_models = {}
        self.tts_models = {}

    def get_llama_pipeline(self, model_id, token):
        if model_id not in self.llama_pipelines:
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=token,
                torch_dtype=torch.float16,
                device_map="auto",
                # flash_attention_2 needs the flash-attn package and a
                # compatible GPU; swap in "sdpa" if either is unavailable.
                attn_implementation="flash_attention_2",
            )
            # The model is already dispatched by device_map above, so the
            # pipeline must not be given its own device placement.
            self.llama_pipelines[model_id] = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
            )
        return self.llama_pipelines[model_id]

    def get_musicgen_model(self, model_key="facebook/musicgen-large"):
        if model_key not in self.musicgen_models:
            model = MusicgenForConditionalGeneration.from_pretrained(model_key)
            processor = AutoProcessor.from_pretrained(model_key)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)
            self.musicgen_models[model_key] = (model, processor)
        return self.musicgen_models[model_key]

    def get_tts_model(self, model_name):
        if model_name not in self.tts_models:
            self.tts_models[model_name] = TTS(model_name)
        return self.tts_models[model_name]
model_manager = ModelManager()
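# One shared manager instance: each checkpoint is loaded on first use and then
# served from memory, so repeated generations skip the from_pretrained cost.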
# -------------------------------
# Core Functions
# -------------------------------
@spaces.GPU(duration=120)
def generate_script(user_prompt, model_id, duration, temperature=0.7, max_tokens=512):
    try:
        text_pipeline = model_manager.get_llama_pipeline(model_id, HF_TOKEN)
        system_prompt = f"""You are an AI audio production assistant. Create content for a {duration}-second promo:
1. Voice Script: [Clear, engaging narration]
2. Sound Design: [3-5 specific sound effects]
3. Music: [Genre, tempo, mood suggestions]
Keep sections concise and production-ready."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        response = text_pipeline(
            messages,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=0.95,
            eos_token_id=text_pipeline.tokenizer.eos_token_id,
        )
        # With chat-style input the pipeline returns the full message list;
        # the assistant's reply is the last entry.
        return parse_generated_content(response[0]["generated_text"][-1]["content"])
    except Exception as e:
        return f"Error: {str(e)}", "", ""
def parse_generated_content(text):
    sections = {
        "Voice Script": "",
        "Sound Design": "",
        "Music": "",
    }
    current_section = None
    for line in text.split("\n"):
        line = line.strip()
        if "Voice Script:" in line:
            current_section = "Voice Script"
            line = line.replace("Voice Script:", "").strip()
        elif "Sound Design:" in line:
            current_section = "Sound Design"
            line = line.replace("Sound Design:", "").strip()
        elif "Music:" in line:
            current_section = "Music"
            line = line.replace("Music:", "").strip()
        if current_section and line:
            sections[current_section] += line + "\n"
    return (
        sections["Voice Script"].strip(),
        sections["Sound Design"].strip(),
        sections["Music"].strip(),
    )
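# parse_generated_content() expects headed lines in the model output, e.g.:
#   1. Voice Script: Welcome back to TechTalk...
#   2. Sound Design: riser, keyboard clicks, soft whoosh
#   3. Music: upbeat synthwave, ~120 BPM, optimistic
# Text before the first recognized header is discarded.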
@spaces.GPU(duration=100)
def generate_voice(script, tts_model, speed=1.0):
    try:
        if not script.strip():
            raise ValueError("Empty script")
        tts = model_manager.get_tts_model(tts_model)
        output_path = os.path.join(tempfile.gettempdir(), "enhanced_voice.wav")
        tts.tts_to_file(
            text=script,
            file_path=output_path,
            speed=speed,
        )
        return output_path
    except Exception as e:
        # Raise instead of returning the message, so the failure surfaces in
        # the UI rather than being treated as a file path by gr.Audio.
        raise gr.Error(f"Voice generation failed: {e}")
@spaces.GPU(duration=150)
def generate_music(prompt, duration_sec=30, temperature=1.0, guidance_scale=3.0):
    try:
        model, processor = model_manager.get_musicgen_model()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to(device)
        audio_values = model.generate(
            **inputs,
            max_new_tokens=int(duration_sec * 50),
            temperature=temperature,
            guidance_scale=guidance_scale,
            do_sample=True,
        )
        output_path = os.path.join(tempfile.gettempdir(), "enhanced_music.wav")
        # Read the sampling rate from the model config (32 kHz for MusicGen)
        # rather than hard-coding it.
        sampling_rate = model.config.audio_encoder.sampling_rate
        write(output_path, sampling_rate, audio_values[0, 0].cpu().numpy())
        return output_path
    except Exception as e:
        raise gr.Error(f"Music generation failed: {e}")
def blend_audio(voice_path, music_path, ducking=True, duck_level=10, crossfade=500):
    try:
        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)
        # Loop the music bed until it covers the voiceover, then trim it to
        # length and fade it out.
        if len(music) < len(voice):
            loops = (len(voice) // len(music)) + 1
            music = music * loops
        music = music[:len(voice)].fade_out(crossfade)
        if ducking:
            ducked_music = music - duck_level
            mixed = ducked_music.overlay(voice.fade_in(crossfade))
        else:
            mixed = music.overlay(voice)
        output_path = os.path.join(tempfile.gettempdir(), "enhanced_mix.wav")
        mixed.export(output_path, format="wav")
        return output_path
    except Exception as e:
        raise gr.Error(f"Mixing failed: {e}")
# -------------------------------
# Gradio Interface
# -------------------------------
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="teal",
).set(
    body_text_color_dark="#FFFFFF",
    background_fill_primary_dark="#1F1F1F",
)
with gr.Blocks(theme=theme, title="AI Audio Studio Pro") as demo:
    gr.Markdown("""
    # 🎙️ AI Audio Studio Pro
    *Next-generation audio production powered by AI*
    """)
    with gr.Tabs():
        with gr.Tab("🎯 Concept Development"):
            with gr.Row():
                with gr.Column(scale=2):
                    concept_input = gr.Textbox(
                        label="Your Concept",
                        placeholder="Describe your audio project...",
                        lines=3,
                        max_lines=6,
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            model_selector = gr.Dropdown(
                                choices=list(MODEL_CONFIG["llama_models"].values()),
                                label="AI Model",
                                value=MODEL_CONFIG["llama_models"]["Meta-Llama-3-8B"],
                            )
                            duration_slider = gr.Slider(15, 120, value=30, step=15, label="Duration (seconds)")
                        with gr.Row():
                            temp_slider = gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Creativity")
                            token_slider = gr.Slider(128, 1024, value=512, step=128, label="Max Length")
                    generate_btn = gr.Button("✨ Generate Concept", variant="primary")
                with gr.Column(scale=1):
                    script_output = gr.Textbox(label="Voice Script", interactive=True)
                    sound_output = gr.Textbox(label="Sound Design", interactive=True)
                    music_output = gr.Textbox(label="Music Suggestions", interactive=True)
            generate_btn.click(
                generate_script,
                inputs=[concept_input, model_selector, duration_slider, temp_slider, token_slider],
                outputs=[script_output, sound_output, music_output],
            )
        with gr.Tab("🗣️ Voice Production"):
            with gr.Row():
                with gr.Column():
                    tts_model = gr.Dropdown(
                        choices=list(MODEL_CONFIG["tts_models"].values()),
                        label="Voice Model",
                        value=MODEL_CONFIG["tts_models"]["Standard English"],
                    )
                    speed_slider = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Speaking Rate")
                    voice_btn = gr.Button("🎙️ Generate Voiceover", variant="primary")
                with gr.Column():
                    # type="filepath" so this component hands blend_audio() a
                    # path rather than a (sample_rate, ndarray) tuple.
                    voice_preview = gr.Audio(label="Preview", interactive=False, type="filepath")
            voice_btn.click(
                generate_voice,
                inputs=[script_output, tts_model, speed_slider],
                outputs=voice_preview,
            )
        with gr.Tab("🎶 Music Production"):
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Music Parameters", open=True):
                        music_duration = gr.Slider(10, 120, value=30, label="Duration (seconds)")
                        music_temp = gr.Slider(0.1, 2.0, value=1.0, label="Creativity")
                        guidance_scale = gr.Slider(1.0, 5.0, value=3.0, label="Focus")
                    music_btn = gr.Button("🎵 Generate Music", variant="primary")
                with gr.Column():
                    music_preview = gr.Audio(label="Preview", interactive=False, type="filepath")
            music_btn.click(
                generate_music,
                inputs=[music_output, music_duration, music_temp, guidance_scale],
                outputs=music_preview,
            )
        with gr.Tab("🔊 Final Mix"):
            with gr.Row():
                with gr.Column():
                    ducking_toggle = gr.Checkbox(value=True, label="Enable Voice Ducking")
                    duck_level = gr.Slider(0, 30, value=12, label="Ducking Strength (dB)")
                    crossfade_time = gr.Slider(0, 2000, value=500, label="Crossfade (ms)")
                    mix_btn = gr.Button("🚀 Create Final Mix", variant="primary")
                with gr.Column():
                    final_mix = gr.Audio(label="Master Output", interactive=False)
            mix_btn.click(
                blend_audio,
                inputs=[voice_preview, music_preview, ducking_toggle, duck_level, crossfade_time],
                outputs=final_mix,
            )
    with gr.Accordion("📚 Example Prompts", open=False):
        gr.Examples(
            examples=[
                ["A 30-second tech podcast intro with futuristic sounds"],
                ["A 15-second radio ad for a coffee shop with morning vibes"],
                ["A 60-second documentary trailer with epic orchestral music"],
            ],
            inputs=concept_input,
        )
    with gr.Row():
        gr.Markdown("### System Resources")
        # Placeholder readouts: no callback currently updates these fields.
        gpu_status = gr.Textbox(label="GPU Utilization", interactive=False)
        ram_status = gr.Textbox(label="RAM Usage", interactive=False)
    # Custom Footer
    gr.Markdown("""
    <hr>
    <p style="text-align: center; font-size: 0.9em;">
    Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
    </p>
    """)
    gr.HTML("""
    <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
        <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
    </a>
    """)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)