# Streamlit app: drafts a radio promo script with a Llama 3 text-generation
# pipeline, then feeds that script to MusicGen to produce a short audio clip.
#
# NOTE: the header is a comment (not a module docstring) on purpose, so that
# no bare string expression precedes `st.set_page_config`.
from io import BytesIO
from typing import Optional

import requests
import scipy.io.wavfile
import streamlit as st
import torch
from streamlit_lottie import st_lottie
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    MusicgenForConditionalGeneration,
    pipeline,
)

# ---------------------------------------------------------------------
# 1) PAGE CONFIG
# ---------------------------------------------------------------------
st.set_page_config(
    page_title="Radio Imaging AI with Llama 3",
    page_icon="🎧",
    layout="wide",
)

# ---------------------------------------------------------------------
# 2) CUSTOM CSS / SPOTIFY-LIKE UI
# ---------------------------------------------------------------------
# NOTE(review): this string is effectively empty here; the stylesheet markup
# appears to have been stripped from the file -- restore it if available.
CUSTOM_CSS = """ """
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)


# ---------------------------------------------------------------------
# 3) LOAD LOTTIE ANIMATION
# ---------------------------------------------------------------------
@st.cache_data
def load_lottie_url(url: str) -> Optional[dict]:
    """Fetch a Lottie animation as parsed JSON.

    Args:
        url: Address of the Lottie JSON document.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        server answers with a non-200 status, or the body is not valid JSON.
    """
    try:
        # A timeout keeps a stalled CDN from hanging the whole page load.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        # The animation is decorative; never let a network error kill the app.
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # requests' JSON decode errors derive from ValueError.
        return None


LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)


# ---------------------------------------------------------------------
# 4) LOAD LLAMA 3 (GATED MODEL) - WITH HUGGING FACE TOKEN
# ---------------------------------------------------------------------
def _select_torch_dtype(device: str) -> torch.dtype:
    """Return float16 only when ``device`` is "auto" and an accelerator exists.

    With "auto" on a CPU-only host the weights end up on the CPU, where half
    precision is poorly supported, so fall back to float32 in that case.
    """
    if device != "auto":
        return torch.float32

    mps_backend = getattr(torch.backends, "mps", None)
    has_mps = mps_backend is not None and mps_backend.is_available()
    if torch.cuda.is_available() or has_mps:
        return torch.float16
    return torch.float32


@st.cache_resource
def load_llama_pipeline(model_id: str, device: str):
    """Build (and cache) a text-generation pipeline for ``model_id``.

    Requires ``huggingface-cli login`` if the model is gated.

    Args:
        model_id: Hugging Face model repository id.
        device: Value forwarded as ``device_map`` ("auto" or "cpu" in this app).

    Returns:
        A ``transformers`` "text-generation" pipeline wrapping the model and
        its tokenizer.
    """
    # `token=True` reads the locally stored login token; it replaces the
    # deprecated `use_auth_token=True`.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=_select_torch_dtype(device),
        device_map=device,
        token=True,
    )
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device,
    )
    return text_gen_pipeline


# ---------------------------------------------------------------------
# 5) REFINE SCRIPT (LLAMA)
# ---------------------------------------------------------------------
def generate_radio_script(user_input: str, pipeline_llama) -> str:
    """Turn a free-form promo concept into a short script.

    Args:
        user_input: The concept typed by the user.
        pipeline_llama: Callable text-generation pipeline; it is invoked with
            the prompt plus sampling options and must return a list whose
            first item has a ``"generated_text"`` key.

    Returns:
        The generated script followed by a fixed attribution line.
    """
    system_prompt = (
        "You are a top-tier radio imaging producer using Llama 3. "
        "Take the user's concept and craft a short, creative promo script."
    )
    marker = "Refined script:"
    combined_prompt = f"{system_prompt}\nUser concept: {user_input}\n{marker}"

    result = pipeline_llama(
        combined_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
    )
    output_text = result[0]["generated_text"]

    if output_text.startswith(combined_prompt):
        # Drop the echoed prompt exactly. Splitting on the marker alone would
        # cut in the wrong place if the user's concept contains the marker.
        output_text = output_text[len(combined_prompt):].strip()
    elif marker in output_text:
        output_text = output_text.split(marker, 1)[-1].strip()

    output_text += "\n\n(Generated by Llama 3 - Radio Imaging)"
    return output_text


# ---------------------------------------------------------------------
# 6) LOAD MUSICGEN
# ---------------------------------------------------------------------
@st.cache_resource
def load_musicgen_model():
    """Load (and cache) the "facebook/musicgen-small" model and processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    mg_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
    mg_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    return mg_model, mg_processor


# ---------------------------------------------------------------------
# 7) SIDEBAR
# ---------------------------------------------------------------------
with st.sidebar:
    st.header("🎚 Radio Library")
    st.write("**My Stations**")
    st.write("- Favorites")
    st.write("- Recently Generated")
    st.write("- Top Hits")
    st.write("---")
    st.write("**Settings**")
    # NOTE(review): only a newline survives in this string; the original
    # spacer markup was presumably stripped -- confirm and restore.
    st.markdown("\n", unsafe_allow_html=True)

# ---------------------------------------------------------------------
# 8) HEADER
# ---------------------------------------------------------------------
col1, col2 = st.columns([3, 2], gap="large")

with col1:
    st.title("AI Radio Imaging with Llama 3")
    st.subheader("Gated Model + MusicGen Audio")
    st.markdown(
        """
        Create **radio imaging promos** and **jingles** with Llama 3 + MusicGen.

        **Note**: You must have access to `"meta-llama/Llama-3-70B-Instruct"` on Hugging Face,
        and be logged in via `huggingface-cli login`.
        """
    )

with col2:
    if lottie_animation:
        with st.container():
            st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")
    else:
        st.write("*No animation loaded.*")

st.markdown("---")

# ---------------------------------------------------------------------
# 9) SCRIPT GENERATION
# ---------------------------------------------------------------------
st.subheader("🎙 Step 1: Describe Your Promo Idea")

prompt = st.text_area(
    "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
    height=120,
)

col_model, col_device = st.columns(2)
with col_model:
    # NOTE(review): confirm this default repo id resolves on the Hub; the
    # Llama 3 instruct checkpoints may be published under a different name.
    llama_model_id = st.text_input(
        "Llama 3 Model ID",
        value="meta-llama/Llama-3-70B-Instruct",  # Official ID if you have it
        help="Use the exact name you see on the Hugging Face model page.",
    )
with col_device:
    device_option = st.selectbox(
        "Device (GPU vs CPU)",
        ["auto", "cpu"],
        help="If you have GPU, 'auto' tries to use it; CPU might be slow.",
    )

if st.button("📝 Generate Promo Script"):
    if not prompt.strip():
        st.error("Please type some concept first.")
    else:
        with st.spinner("Generating script with Llama 3..."):
            try:
                llm_pipeline = load_llama_pipeline(llama_model_id, device_option)
                final_script = generate_radio_script(prompt, llm_pipeline)
            except Exception as e:  # UI boundary: report the failure, don't crash
                st.error(f"Llama generation error: {e}")
            else:
                st.session_state["final_script"] = final_script
                st.success("Promo script generated!")

# Render from session state so the script stays visible on later reruns
# (any widget interaction, e.g. pressing the audio button, reruns the script).
if "final_script" in st.session_state:
    st.write(st.session_state["final_script"])

st.markdown("---")

# ---------------------------------------------------------------------
# 10) AUDIO GENERATION: MUSICGEN
# ---------------------------------------------------------------------
st.subheader("🎶 Step 2: Generate Audio")

audio_length = st.slider("MusicGen Max Tokens (approx track length)", 128, 1024, 512, 64)

if st.button("🎧 Create Audio with MusicGen"):
    if "final_script" not in st.session_state:
        st.error("No script found. Please generate a script first.")
    else:
        with st.spinner("Creating audio..."):
            try:
                mg_model, mg_processor = load_musicgen_model()
                text_for_audio = st.session_state["final_script"]

                inputs = mg_processor(
                    text=[text_for_audio],
                    padding=True,
                    return_tensors="pt",
                )
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
                sr = mg_model.config.audio_encoder.sampling_rate

                # Encode the WAV in memory instead of a fixed file name in the
                # working directory: that path would be shared (and
                # overwritten) by every concurrent session.
                wav_buffer = BytesIO()
                scipy.io.wavfile.write(
                    wav_buffer,
                    rate=sr,
                    # .cpu() keeps the conversion valid wherever the tensor lives.
                    data=audio_values[0, 0].cpu().numpy(),
                )

                st.success("Audio generated! Press play below:")
                st.audio(wav_buffer.getvalue(), format="audio/wav")
            except Exception as e:  # UI boundary: report the failure, don't crash
                st.error(f"MusicGen error: {e}")

# ---------------------------------------------------------------------
# 11) FOOTER
# ---------------------------------------------------------------------
st.markdown("---")
# NOTE(review): the footer markup appears to have been stripped; the string
# is kept as found.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)