import spaces
from snac import SNAC
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import google.generativeai as genai
import re
import logging
import numpy as np
from pydub import AudioSegment
import io

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SNAC model...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to(device)

model_name = "canopylabs/orpheus-3b-0.1-ft"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Orpheus model loaded to {device}")

# Available voices
VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]

# Available emotive tags (rendered inline by Orpheus)
EMOTIVE_TAGS = ["<laugh>", "<chuckle>", "<sigh>", "<cough>", "<sniffle>", "<groan>", "<yawn>", "<gasp>"]


# No GPU needed here: this function only calls the Gemini API.
def generate_podcast_script(api_key, prompt, uploaded_file, duration, num_hosts):
    try:
        genai.configure(api_key=api_key)
        gemini_model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        combined_content = prompt or ""
        if uploaded_file:
            # gr.File may pass a filepath (Gradio 4.x) or a file-like object
            if isinstance(uploaded_file, str):
                with open(uploaded_file, "r", encoding="utf-8") as f:
                    file_content = f.read()
            else:
                file_content = uploaded_file.read().decode("utf-8")
            combined_content += ("\n" + file_content) if combined_content else file_content

        num_hosts = int(num_hosts)  # Radio input arrives as a string

        prompt = f"""
        Create a podcast script for {num_hosts} {'person' if num_hosts == 1 else 'people'} discussing: {combined_content}
        Duration: {duration} minutes. Include natural speech, humor, and occasional off-topic thoughts.
        Use speech fillers like um, ah. Vary emotional tone.
        Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
        Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.
        Only provide the dialogue for text-to-speech.
        Use emotion tags in angle brackets: <laugh>, <chuckle>, <sigh>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.
        Example: "I can't believe I stayed up all night only to find out the meeting was canceled <sigh>."
        Ensure content flows naturally and stays on topic.
        Match the script length to {duration} minutes.
        {'Make sure the script is a monologue for one person.' if num_hosts == 1 else 'Ensure the dialogue alternates between two distinct voices, with one speaking on odd-numbered lines and the other on even-numbered lines.'}
        """

        response = gemini_model.generate_content(prompt)
        # Keep letters, digits, whitespace, basic punctuation, and the <angle-bracket> emotion tags
        return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
    except Exception as e:
        logger.error(f"Error generating podcast script: {str(e)}")
        raise
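# Example (hypothetical, not wired into the UI): generating a script directly,
# with the API key taken from the environment. GEMINI_API_KEY, the topic, and
# the 3-minute duration below are placeholder values.
#
#   import os
#   script = generate_podcast_script(
#       api_key=os.environ["GEMINI_API_KEY"],
#       prompt="quantum computing",
#       uploaded_file=None,
#       duration=3,
#       num_hosts="2",
#   )
#   print(script)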
def process_prompt(prompt, voice, tokenizer, device):
    prompt = f"{voice}: {prompt}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    start_token = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)

    attention_mask = torch.ones_like(modified_input_ids)
    return modified_input_ids.to(device), attention_mask.to(device)


def parse_output(generated_ids):
    token_to_find = 128257    # start-of-audio marker
    token_to_remove = 128258  # end-of-speech token

    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
    if len(token_indices[1]) > 0:
        last_occurrence_idx = token_indices[1][-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped_tensor = generated_ids

    processed_rows = []
    for row in cropped_tensor:
        masked_row = row[row != token_to_remove]
        processed_rows.append(masked_row)

    code_lists = []
    for row in processed_rows:
        row_length = row.size(0)
        new_length = (row_length // 7) * 7  # trim to whole 7-token frames
        trimmed_row = row[:new_length]
        trimmed_row = [t - 128266 for t in trimmed_row]  # shift into SNAC code range
        code_lists.append(trimmed_row)

    return code_lists[0]


def redistribute_codes(code_list, snac_model):
    device = next(snac_model.parameters()).device  # Get the device of the SNAC model

    # Each 7-token frame is interleaved across SNAC's three codebook layers
    layer_1 = []
    layer_2 = []
    layer_3 = []
    for i in range((len(code_list) + 1) // 7):
        layer_1.append(code_list[7 * i])
        layer_2.append(code_list[7 * i + 1] - 4096)
        layer_3.append(code_list[7 * i + 2] - (2 * 4096))
        layer_3.append(code_list[7 * i + 3] - (3 * 4096))
        layer_2.append(code_list[7 * i + 4] - (4 * 4096))
        layer_3.append(code_list[7 * i + 5] - (5 * 4096))
        layer_3.append(code_list[7 * i + 6] - (6 * 4096))

    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()  # Always return a CPU numpy array


@spaces.GPU()
def generate_speech(text, voice1, voice2, temperature, top_p, repetition_penalty, max_new_tokens, num_hosts, progress=gr.Progress()):
    if not text.strip():
        return None

    try:
        # Load the intro/outro music and match it to the speech format (24 kHz, mono, 16-bit)
        music = AudioSegment.from_mp3("Maiko-intro-outro.mp3")
        music = music.set_frame_rate(24000).set_channels(1).set_sample_width(2)

        progress(0.1, "Processing text...")
        lines = text.split('\n')
        audio_samples = []

        for i, line in enumerate(lines):
            if not line.strip():
                continue

            # With two hosts, alternate voices line by line
            if num_hosts == "2":
                voice = voice1 if i % 2 == 0 else voice2
            else:
                voice = voice1

            input_ids, attention_mask = process_prompt(line, voice, tokenizer, device)

            progress(0.3, f"Generating speech tokens for line {i+1}...")
            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    max_new_tokens=max_new_tokens,
                    num_return_sequences=1,
                    eos_token_id=128258,
                )

            progress(0.6, f"Processing speech tokens for line {i+1}...")
            code_list = parse_output(generated_ids)

            progress(0.8, f"Converting line {i+1} to audio...")
            line_audio = redistribute_codes(code_list, snac_model)
            audio_samples.append(line_audio)

        # Concatenate all audio samples
        final_audio = np.concatenate(audio_samples)

        # SNAC decodes to float32 in [-1, 1]; convert to 16-bit PCM for pydub
        pcm_audio = (np.clip(final_audio, -1.0, 1.0) * 32767).astype(np.int16)
        speech_audio = AudioSegment(
            pcm_audio.tobytes(),
            frame_rate=24000,
            sample_width=pcm_audio.dtype.itemsize,
            channels=1,
        )

        # Combine intro, speech, and outro
        combined_audio = music + speech_audio + music

        # Convert back to a numpy array
        combined_numpy = np.array(combined_audio.get_array_of_samples(), dtype=np.int16)

        # Truncate to the 15-second limit
        max_samples = 24000 * 15  # 15 seconds at a 24 kHz sample rate
        if len(combined_numpy) > max_samples:
            combined_numpy = combined_numpy[:max_samples]

        return (24000, combined_numpy)
    except Exception as e:
        print(f"Error generating speech: {e}")
        return None
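# Quick sanity check (hypothetical, not part of the UI): run a single line
# through the full token -> SNAC -> waveform pipeline. The text and voice are
# placeholder values; writing "sample.wav" assumes the soundfile package.
#
#   import soundfile as sf
#   ids, mask = process_prompt("Hello there <chuckle>, welcome to the show.", "tara", tokenizer, device)
#   with torch.no_grad():
#       out = model.generate(ids, attention_mask=mask, do_sample=True,
#                            max_new_tokens=1200, eos_token_id=128258)
#   waveform = redistribute_codes(parse_output(out), snac_model)
#   sf.write("sample.wav", waveform, 24000)  # SNAC decodes at 24 kHz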
with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gemini_api_key = gr.Textbox(label="Gemini API Key", type="password")
            prompt = gr.Textbox(label="Prompt", lines=8, placeholder="Enter your text here...")
            uploaded_file = gr.File(label="Upload File")
            duration = gr.Slider(minimum=1, maximum=60, value=5, step=1, label="Duration (minutes)")
            num_hosts = gr.Radio(["1", "2"], label="Number of Hosts", value="1")
            generate_script_btn = gr.Button("Generate Podcast Script")

        with gr.Column(scale=2):
            voice1 = gr.Dropdown(
                choices=VOICES,
                value="tara",
                label="Voice 1",
                info="Select the first voice for speech generation"
            )
            voice2 = gr.Dropdown(
                choices=VOICES,
                value="dan",
                label="Voice 2",
                info="Select the second voice for speech generation"
            )

            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.6, step=0.05,
                    label="Temperature",
                    info="Higher values (0.7-1.0) create more expressive but less stable speech"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top P",
                    info="Higher values produce more diverse outputs"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                    label="Repetition Penalty",
                    info="Higher values discourage repetitive patterns"
                )
                max_new_tokens = gr.Slider(
                    minimum=100, maximum=2000, value=1200, step=100,
                    label="Max Length",
                    info="Maximum length of generated audio (in tokens)"
                )

            with gr.Row():
                submit_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")

        with gr.Column(scale=2):
            script_output = gr.Textbox(label="Generated Script", lines=10)
            audio_output = gr.Audio(label="Generated Speech", type="numpy")

    generate_script_btn.click(
        fn=generate_podcast_script,
        inputs=[gemini_api_key, prompt, uploaded_file, duration, num_hosts],
        outputs=script_output
    )

    submit_btn.click(
        fn=generate_speech,
        inputs=[script_output, voice1, voice2, temperature, top_p, repetition_penalty, max_new_tokens, num_hosts],
        outputs=audio_output
    )

    clear_btn.click(
        fn=lambda: (None, None, None),
        inputs=[],
        outputs=[prompt, script_output, audio_output]
    )

if __name__ == "__main__":
    demo.queue().launch(share=False, ssr_mode=False)
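# Dependency note: the imports above require the snac, transformers, gradio,
# google-generativeai, pydub, and spaces packages; pydub additionally needs
# ffmpeg on the system path to decode the intro/outro MP3.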