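# NOTE: the following header is residue from the web page this file was copied
# from (Hugging Face file view); it is not part of the program.
#   Author avatar: bluenevus's picture
#   Commit: "Update app.py" (ac511b5, verified)
#   Page links: raw / history / blame
#   File size: 13.8 kB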
import spaces
from snac import SNAC
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
import google.generativeai as genai
import re
import logging
import numpy as np
from pydub import AudioSegment
import io
from docx import Document
import PyPDF2
# Module-level setup: logging, device selection, and one-time model loading.
# Everything below runs at import time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the SNAC model and move it to the selected device; it is used later by
# redistribute_codes() to decode code layers into a waveform.
print("Loading SNAC model...")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.to(device)
# Load the causal language model (requested in bfloat16) and its tokenizer;
# generate_speech() uses both.
model_name = "canopylabs/orpheus-3b-0.1-ft"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Orpheus model loaded to {device}")
# Available voices (offered as choices in the two voice dropdowns of the UI).
VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
# Available emotive tags. Each entry is wrapped in backticks; the list is not
# referenced anywhere else in the visible part of this file.
EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
@spaces.GPU()
def generate_podcast_script(api_key, prompt, uploaded_file, duration, num_hosts):
    """Ask Gemini to write a podcast script about the prompt and/or uploaded file.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.
        prompt: Free-text topic; may be empty or ``None``.
        uploaded_file: Raw bytes of an uploaded PDF, UTF-8 text file, or DOCX,
            or ``None`` when nothing was uploaded.
        duration: Target script length in minutes (interpolated into the prompt).
        num_hosts: Number of hosts; converted with ``int()``.

    Returns:
        The generated script with every character outside
        ``a-z A-Z 0-9 whitespace . , ? ! < >`` removed.

    Raises:
        ValueError: If the uploaded file is not a readable PDF, UTF-8 text, or
            DOCX document.
        Exception: Any other failure is logged with its traceback and re-raised.
    """
    try:
        genai.configure(api_key=api_key)
        # Named distinctly so it does not shadow the module-level speech model.
        gemini_model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        combined_content = prompt or ""
        if uploaded_file is not None:
            file_content = _extract_uploaded_text(uploaded_file)
            if combined_content:
                combined_content = combined_content + "\n" + file_content
            else:
                combined_content = file_content

        num_hosts = int(num_hosts)  # Convert to integer
        prompt_template = f"""
Create a podcast script for {num_hosts} {'person' if num_hosts == 1 else 'people'} discussing:
{combined_content}
Duration: {duration} minutes. Include natural speech, humor, and occasional off-topic thoughts.
Use speech fillers like um, ah. Vary emotional tone.
Format: {'Monologue' if num_hosts == 1 else 'Alternating dialogue'} without speaker labels.
Separate {'paragraphs' if num_hosts == 1 else 'lines'} with blank lines.
only provide the dialog for text to speech
Only use these emotion tags in angle brackets: <laugh>, <sigh>, <chuckle>, <cough>, <sniffle>, <groan>, <yawn>, <gasp>.
-Example: "I can't believe I stayed up all night <yawn> only to find out the meeting was canceled <groan>."
Ensure content flows naturally and stays on topic. Match the script length to {duration} minutes.
Do not include speaker labels like "John:" or "Sara:" before dialogue.
The intro always includes the first speaker and should be in the same paragraph.
The outro always includes the first speaker and should be in the same paragraph
Do not include these types of transition "Intro Music fades in and then fades slightly to background"
Keep each speaker's entire monologue in a single paragraph, regardless of length if the number of hosts is not 1.
Start a new paragraph only when switching to a different speaker if the number of hosts is not 1.
Maintain natural conversation flow and speech patterns within each monologue.
Use context clues or subtle references to indicate who is speaking without explicit labels if the number of hosts is not 1
Use speaker names sparingly, only when necessary for clarity or emphasis. Avoid starting every line with the other person's name.
Rely more on context and speech patterns to indicate who is speaking, rather than always stating names.
Use names primarily for transitions sparingly, definitely with agreements, or to draw attention to a specific point, not as a constant form of address.
{'Make sure the script is a monologue for one person.' if num_hosts == 1 else 'Ensure the dialogue alternates between two distinct voices, with one speaking on odd-numbered lines and the other on even-numbered lines.'}
"""
        response = gemini_model.generate_content(prompt_template)
        # NOTE(review): this whitelist also strips apostrophes and hyphens
        # ("can't" -> "cant"); confirm that is intended for the TTS input.
        return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
    except Exception as e:
        # Log with traceback, then let the caller see the failure.
        logger.exception("Error generating podcast script: %s", e)
        raise


def _extract_uploaded_text(raw_bytes):
    """Return the text of an uploaded PDF, UTF-8 text file, or DOCX document.

    The type is sniffed from content: a ``%PDF`` header selects the PDF
    reader; otherwise the bytes are tried as UTF-8 text and, if that fails,
    as a DOCX document.

    Raises:
        ValueError: If the bytes are neither valid UTF-8 nor a readable DOCX.
    """
    file_bytes = io.BytesIO(raw_bytes)
    if file_bytes.read(4) == b'%PDF':
        file_bytes.seek(0)
        pdf_reader = PyPDF2.PdfReader(file_bytes)
        # ``or ""`` guards against a page that yields no text object.
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)

    file_bytes.seek(0)
    try:
        return file_bytes.read().decode('utf-8')
    except UnicodeDecodeError:
        pass  # Not plain text; fall through to the DOCX attempt.

    file_bytes.seek(0)
    try:
        doc = Document(file_bytes)
        return "\n".join(para.text for para in doc.paragraphs)
    except Exception as err:
        # Narrower than a bare ``except:`` (does not trap KeyboardInterrupt /
        # SystemExit) and keeps the original cause attached.
        raise ValueError("Unsupported file type or corrupted file") from err
def process_prompt(prompt, voice, tokenizer, device):
    """Tokenize ``"<voice>: <prompt>"`` and wrap it in the fixed marker tokens.

    The tokenized text is prefixed with token id 128259 and suffixed with
    token ids 128009 and 128260. An all-ones attention mask of the same shape
    is built alongside it.

    Returns:
        ``(input_ids, attention_mask)``, both moved to ``device``.
    """
    text_ids = tokenizer(f"{voice}: {prompt}", return_tensors="pt").input_ids

    prefix = torch.tensor([[128259]], dtype=torch.int64)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)
    framed_ids = torch.cat((prefix, text_ids, suffix), dim=1)

    mask = torch.ones_like(framed_ids)
    return framed_ids.to(device), mask.to(device)
def parse_output(generated_ids):
    """Turn generated token ids into a flat list of zero-based codes.

    Steps, applied to the 2-D ``generated_ids`` tensor:
      1. Keep only the columns after the last occurrence of token 128257
         (everything is kept when that token is absent).
      2. Remove every occurrence of token 128258 from each row.
      3. Trim each row to a multiple of 7 entries.
      4. Subtract 128266 from every remaining entry.

    Returns:
        The processed codes of the first row, as a list of 0-d tensors.
    """
    crop_marker = 128257
    drop_token = 128258
    code_offset = 128266
    group = 7

    marker_cols = (generated_ids == crop_marker).nonzero(as_tuple=True)[1]
    if len(marker_cols) > 0:
        tail = generated_ids[:, marker_cols[-1].item() + 1:]
    else:
        tail = generated_ids

    per_row_codes = []
    for row in tail:
        kept = row[row != drop_token]
        usable = (kept.size(0) // group) * group
        per_row_codes.append([tok - code_offset for tok in kept[:usable]])

    return per_row_codes[0]
def redistribute_codes(code_list, snac_model):
    """De-interleave 7-code frames into three code layers and decode them.

    Each group of 7 consecutive codes is split as follows (offsets are
    multiples of 4096 that are subtracted from the stored value):
      * position 0            -> layer 1
      * positions 1, 4        -> layer 2
      * positions 2, 3, 5, 6  -> layer 3

    Args:
        code_list: Sequence of codes (ints or 0-d tensors). Trailing codes
            that do not fill a complete group of 7 are ignored.
        snac_model: Model exposing ``parameters()`` and ``decode(codes)``.

    Returns:
        The decoded audio as a squeezed CPU NumPy array.

    Raises:
        ValueError: If ``code_list`` holds fewer than 7 codes.
    """
    device = next(snac_model.parameters()).device  # Get the device of SNAC model

    # Only complete frames are used. The previous ``(len + 1) // 7`` counted
    # one frame too many for inputs of length 7k + 6 and raised IndexError.
    num_frames = len(code_list) // 7
    if num_frames == 0:
        raise ValueError("code_list must contain at least one complete frame of 7 codes")

    layer_1 = []
    layer_2 = []
    layer_3 = []
    for i in range(num_frames):
        base = 7 * i
        layer_1.append(code_list[base])
        layer_2.append(code_list[base + 1] - 4096)
        layer_3.append(code_list[base + 2] - (2 * 4096))
        layer_3.append(code_list[base + 3] - (3 * 4096))
        layer_2.append(code_list[base + 4] - (4 * 4096))
        layer_3.append(code_list[base + 5] - (5 * 4096))
        layer_3.append(code_list[base + 6] - (6 * 4096))

    codes = [
        torch.tensor(layer_1, device=device).unsqueeze(0),
        torch.tensor(layer_2, device=device).unsqueeze(0),
        torch.tensor(layer_3, device=device).unsqueeze(0),
    ]
    # Decoding is inference-only; skip autograd bookkeeping.
    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
def detect_silence(audio, threshold=0.01, min_silence_duration=1.2, sample_rate=24000):
    """Find long runs of consecutive near-silent samples.

    A sample counts as silent when its absolute value is below ``threshold``.
    Consecutive silent samples form a run; a run is reported when
    ``(end - start) / sample_rate >= min_silence_duration``.

    Args:
        audio: 1-D array-like of samples.
        threshold: Amplitude below which a sample is considered silent.
        min_silence_duration: Minimum run length, in seconds, to report.
        sample_rate: Samples per second used to convert indices to seconds.

    Returns:
        A list of ``(start, end)`` sample-index tuples (``end`` inclusive),
        in ascending order. Empty when no run is long enough.
    """
    is_silent = np.abs(audio) < threshold
    silent_idx = np.where(is_silent)[0]
    if silent_idx.size == 0:
        return []

    # A jump larger than 1 between neighbouring silent indices marks the
    # boundary between two separate runs.
    gaps = np.flatnonzero(np.diff(silent_idx) > 1)
    starts = np.concatenate((silent_idx[:1], silent_idx[gaps + 1]))
    ends = np.concatenate((silent_idx[gaps], silent_idx[-1:]))

    long_enough = (ends - starts) / sample_rate >= min_silence_duration
    return list(zip(starts[long_enough].tolist(), ends[long_enough].tolist()))
@spaces.GPU()
def generate_speech(text, voice1, voice2, temperature, top_p, repetition_penalty, max_new_tokens, num_hosts, progress=gr.Progress()):
    """Synthesize speech for a script, one paragraph at a time.

    The script is split on blank lines. With one host every paragraph uses
    ``voice1``; otherwise spoken paragraphs alternate ``voice1`` / ``voice2``.
    Each paragraph is tokenized, passed through the language model, decoded
    to audio, trimmed at the end of its last long silence (if any), and the
    pieces are concatenated and peak-normalized to 16-bit PCM.

    Args:
        text: Script text; paragraphs are separated by blank lines.
        voice1: Voice name for the first (or only) speaker.
        voice2: Voice name for the second speaker.
        temperature, top_p, repetition_penalty: Sampling settings forwarded
            to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens per paragraph.
        num_hosts: ``"1"`` (or ``1``) for a single voice; anything else
            alternates between the two voices.
        progress: Gradio progress callback.

    Returns:
        ``(24000, int16_audio)`` on success, or ``None`` when there is no
        text, no audio could be produced, or an error occurred (the error is
        logged).
    """
    if not text or not text.strip():
        return None
    try:
        progress(0.1, "Processing text...")
        # Drop blank chunks BEFORE numbering so that voice alternation follows
        # the spoken paragraphs; extra blank lines no longer shift the parity.
        paragraphs = [chunk for chunk in text.split('\n\n') if chunk.strip()]
        single_host = str(num_hosts) == "1"
        total = len(paragraphs)
        step = 0.9 / total  # Share of the progress bar given to each paragraph.

        audio_samples = []
        for i, paragraph in enumerate(paragraphs):
            base = 0.1 + i * step
            voice = voice1 if single_host or i % 2 == 0 else voice2
            input_ids, attention_mask = process_prompt(paragraph, voice, tokenizer, device)

            progress(base + 0.2 * step, f"Generating speech tokens for paragraph {i+1}...")
            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    max_new_tokens=int(max_new_tokens),
                    num_return_sequences=1,
                    eos_token_id=128258,
                )

            progress(base + 0.6 * step, f"Processing speech tokens for paragraph {i+1}...")
            code_list = parse_output(generated_ids)
            if len(code_list) < 7:
                # Nothing decodable for this paragraph; skip it instead of
                # losing the audio of every other paragraph.
                logger.warning("No audio codes generated for paragraph %d; skipping", i + 1)
                continue

            progress(base + 0.8 * step, f"Converting paragraph {i+1} to audio...")
            paragraph_audio = redistribute_codes(code_list, snac_model)

            # Trim the audio at the end of the last detected long silence.
            silences = detect_silence(paragraph_audio)
            if silences:
                paragraph_audio = paragraph_audio[:silences[-1][1]]
            audio_samples.append(paragraph_audio)

        if not audio_samples:
            logger.warning("Speech generation produced no audio")
            return None

        final_audio = np.concatenate(audio_samples)
        # Peak-normalize; skip the division for all-zero audio to avoid NaNs.
        peak = np.max(np.abs(final_audio))
        if peak > 0:
            final_audio = final_audio / peak
        final_audio = np.int16(final_audio * 32767)
        return (24000, final_audio)
    except Exception as e:
        logger.exception("Error generating speech: %s", e)
        return None
# Gradio UI definition and event wiring.
# NOTE(review): the source this was recovered from had its indentation
# stripped, so the nesting of components inside rows/columns below is a
# reconstruction — confirm the layout against the running app.
with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
    with gr.Row():
        # Column 1: script-generation inputs (API key, prompt, optional file).
        with gr.Column(scale=1):
            gemini_api_key = gr.Textbox(label="Gemini API Key", type="password")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text here...",
                lines=5,
                max_lines=30,
                show_label=True,
                interactive=True,
                container=True
            )
            # type="binary" hands the file to the callback as raw bytes.
            uploaded_file = gr.File(label="Upload File", type="binary")
        # Column 2: script settings and the generated script itself.
        with gr.Column(scale=2):
            duration = gr.Slider(minimum=1, maximum=60, value=5, step=1, label="Duration (minutes)")
            # The radio yields the strings "1" / "2", not integers.
            num_hosts = gr.Radio(["1", "2"], label="Number of Hosts", value="1")
            script_output = gr.Textbox(label="Generated Script", lines=10)
            generate_script_btn = gr.Button("Generate Podcast Script")
        # Column 3: voice selection, sampling settings, and audio output.
        with gr.Column(scale=2):
            voice1 = gr.Dropdown(
                choices=VOICES,
                value="tara",
                label="Voice 1",
                info="Select the first voice for speech generation"
            )
            voice2 = gr.Dropdown(
                choices=VOICES,
                value="zac",
                label="Voice 2",
                info="Select the second voice for speech generation"
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.5, value=0.6, step=0.05,
                    label="Temperature",
                    info="Higher values (0.7-1.0) create more expressive but less stable speech"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top P",
                    info="Higher values produce more diverse outputs"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                    label="Repetition Penalty",
                    info="Higher values discourage repetitive patterns"
                )
                max_new_tokens = gr.Slider(
                    minimum=100, maximum=4096, value=2048, step=100,
                    label="Max Length",
                    info="Maximum length of generated audio (in tokens)"
                )
            audio_output = gr.Audio(label="Generated Audio", type="numpy")
            with gr.Row():
                submit_btn = gr.Button("Generate Audio", variant="primary")
                clear_btn = gr.Button("Clear")
    # "Generate Podcast Script" fills the script textbox from the inputs.
    generate_script_btn.click(
        fn=generate_podcast_script,
        inputs=[gemini_api_key, prompt, uploaded_file, duration, num_hosts],
        outputs=script_output
    )
    # "Generate Audio" synthesizes speech from whatever is in the script box.
    submit_btn.click(
        fn=generate_speech,
        inputs=[script_output, voice1, voice2, temperature, top_p, repetition_penalty, max_new_tokens, num_hosts],
        outputs=audio_output
    )
    # "Clear" resets the prompt, the script, and the audio player to None.
    clear_btn.click(
        fn=lambda: (None, None, None),
        inputs=[],
        outputs=[prompt, script_output, audio_output]
    )
# Launch the app with request queueing enabled when run as a script.
if __name__ == "__main__":
    demo.queue().launch(share=False, ssr_mode=False)