Spaces:
Running
Running
File size: 4,696 Bytes
17d10a7 a15d204 d448add db46bfb 1c1b50f db46bfb 1c1b50f db8ba25 db46bfb cf3593c 16060e9 c243adb dfa5d3e cf3593c e7b189b 1c1b50f db8ba25 1c1b50f 07c07fa db8ba25 dfa5d3e db8ba25 dfa5d3e db8ba25 dfa5d3e 8b6a33e dfa5d3e db8ba25 dfa5d3e 16060e9 e7b189b 17d10a7 db8ba25 8b6a33e cf3593c 8b6a33e 17d10a7 d448add cf3593c d448add db8ba25 dfa5d3e db8ba25 dfa5d3e 16060e9 db8ba25 dfa5d3e 5080bd7 16060e9 07c07fa db8ba25 07c07fa 0abc339 07c07fa db8ba25 07c07fa db8ba25 07c07fa db8ba25 07c07fa db8ba25 07c07fa 3fe530b 1a0bb5e a8c9cb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
import os
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
import tempfile
from dotenv import load_dotenv
import spaces # Assumes Hugging Face Spaces library supports `@spaces.GPU`
# Load environment variables (e.g., Hugging Face token)
# HF_TOKEN must grant access to the gated Llama 3 checkpoint; it is read once
# at import time and passed to generate_script() by the UI wrapper below.
load_dotenv()
hf_token = os.getenv("HF_TOKEN")  # None if unset — downstream calls will then hit auth errors
# ---------------------------------------------------------------------
# Load Llama 3 Pipeline with Zero GPU (Encapsulated)
# ---------------------------------------------------------------------
@spaces.GPU(duration=300)  # Request a Zero-GPU slot for up to 300s while this runs
def generate_script(user_prompt: str, model_id: str, token: str) -> str:
    """Generate a short radio-promo script from a user concept with a Llama model.

    Args:
        user_prompt: The user's promo idea (free text).
        model_id: Hugging Face model repo id (e.g. "meta-llama/Meta-Llama-3-8B-Instruct").
        token: Hugging Face access token for gated models (may be None).

    Returns:
        The refined script text, or an "Error generating script: ..." string on
        failure (the Gradio UI displays whatever string comes back, so errors
        are returned rather than raised).
    """
    try:
        # NOTE: `use_auth_token` is deprecated in transformers; `token` is the
        # current parameter name for authenticated downloads.
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=torch.float16,   # halve memory; fine for inference
            device_map="auto",           # place weights on the allocated GPU automatically
            trust_remote_code=True,
        )
        llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        system_prompt = (
            "You are a top-tier radio imaging producer using Llama 3. "
            "Take the user's concept and craft a short, creative promo script."
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nRefined script:"
        result = llama_pipeline(
            combined_prompt,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.9,
        )
        # The pipeline echoes the prompt; keep only the text after our marker.
        return result[0]["generated_text"].split("Refined script:")[-1].strip()
    except Exception as e:
        # Surface the failure in the UI instead of crashing the Space.
        return f"Error generating script: {e}"
# ---------------------------------------------------------------------
# Load MusicGen Model (Encapsulated)
# ---------------------------------------------------------------------
@spaces.GPU(duration=300)  # Request a Zero-GPU slot for up to 300s while this runs
def generate_audio(prompt: str, audio_length: int) -> str:
    """Synthesize audio from a text prompt with MusicGen and save it as a WAV.

    Args:
        prompt: Text description of the desired audio.
        audio_length: Generation length in tokens (the UI slider value).

    Returns:
        Path to a temporary 16-bit PCM .wav file, or an
        "Error generating audio: ..." string on failure (the UI shows it).
    """
    try:
        musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
        musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")

        # Fall back to CPU when no GPU was allocated instead of crashing.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        musicgen_model.to(device)

        # BUG FIX: inputs must live on the same device as the model —
        # previously they stayed on CPU while the model was on CUDA.
        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
        outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
        musicgen_model.to("cpu")  # Release GPU memory for the next request

        sr = musicgen_model.config.audio_encoder.sampling_rate
        waveform = outputs[0, 0].cpu()

        # Peak-normalize to int16. Guard against an all-zero (silent) waveform,
        # which would otherwise cause a division by zero.
        peak = waveform.abs().max().item()
        scale = 32767.0 / peak if peak > 0 else 0.0
        pcm = (waveform.numpy() * scale).astype("int16")

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
            write(temp_wav.name, sr, pcm)
            return temp_wav.name
    except Exception as e:
        # Surface the failure in the UI instead of crashing the Space.
        return f"Error generating audio: {e}"
# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
def interface_generate_script(user_prompt, llama_model_id):
    """UI adapter: forward the prompt and model id, injecting the module token."""
    script = generate_script(user_prompt, llama_model_id, hf_token)
    return script
def interface_generate_audio(script, audio_length):
    """UI adapter: pass the generated script and slider value straight through."""
    wav_path = generate_audio(script, audio_length)
    return wav_path
# ---------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------
# Build the UI. Component creation order defines the rendered layout,
# so the statements below are intentionally sequential.
with gr.Blocks() as demo:
    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
    with gr.Row():
        # Free-text concept the user wants turned into a promo script.
        user_prompt = gr.Textbox(
            label="Enter your promo idea",
            placeholder="E.g., A 15-second hype jingle for a morning talk show.",
        )
        # Overridable model id; defaults to the gated Llama 3 8B Instruct repo.
        llama_model_id = gr.Textbox(
            label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B-Instruct"
        )
        # Passed to MusicGen as max_new_tokens (tokens, not seconds).
        audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
    generate_script_button = gr.Button("Generate Script")
    script_output = gr.Textbox(label="Generated Script")
    generate_audio_button = gr.Button("Generate Audio")
    # type="filepath": generate_audio returns a path to a temp .wav file.
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    # Step 1: prompt + model id -> script text.
    generate_script_button.click(
        fn=interface_generate_script,
        inputs=[user_prompt, llama_model_id],
        outputs=script_output,
    )
    # Step 2: script text + length -> audio file (chained off the script box).
    generate_audio_button.click(
        fn=interface_generate_audio,
        inputs=[script_output, audio_length],
        outputs=audio_output,
    )
# ---------------------------------------------------------------------
# Launch App
# ---------------------------------------------------------------------
# debug=True: verbose server logs; Spaces supplies host/port automatically.
demo.launch(debug=True)
|