# SoloSpeech / app.py
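# Gradio demo for SoloSpeech target speech extraction. Given a mixture recording and a
# short enrollment clip of the target speaker, the app (1) compresses both signals with a
# VAE, (2) samples several candidate extractions with a diffusion-based extractor,
# (3) keeps the candidate whose ECAPA-TDNN speaker embedding is closest to the enrollment,
# and (4) refines it with the GeCo corrector before returning 16 kHz audio.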
import gradio as gr
import spaces
import yaml
import argparse
import os
import torch
import torch.nn.functional as F
import librosa
from diffusers import DDIMScheduler
from solospeech.model.solospeech.conditioners import SoloSpeech_TSE
# from solospeech.model.solospeech.conditioners import SoloSpeech_TSR
from solospeech.vae_modules.autoencoder_wrapper import Autoencoder
from speechbrain.pretrained.interfaces import Pretrained
from solospeech.corrector.fastgeco.model import ScoreModel
from solospeech.corrector.geco.util.other import pad_spec
from huggingface_hub import snapshot_download
import time
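# Thin wrapper around SpeechBrain's Pretrained interface exposing encode_batch(), used
# below to compute speaker embeddings for ranking the extractor's candidate outputs.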
class Encoder(Pretrained):
    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model"
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        # Manage single waveforms in input
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)
        # Assign full length if wav_lens is not assigned
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        # Storing waveform in the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()
        # Computing features and embeddings
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.hparams.mean_var_norm_emb(
                embeddings,
                torch.ones(embeddings.shape[0], device=self.device)
            )
        return embeddings
parser = argparse.ArgumentParser()
# inference hyperparameters
parser.add_argument('--eta', type=int, default=0)
parser.add_argument("--num_infer_steps", type=int, default=200)
parser.add_argument("--num_candidates", type=int, default=4)
parser.add_argument('--sample-rate', type=int, default=16000)
# random seed
parser.add_argument('--random-seed', type=int, default=42, help="Fixed seed")
args = parser.parse_args()
print("Downloading model from Huggingface...")
local_dir = snapshot_download(
    repo_id="OpenSound/SoloSpeech-models"
)
args.tse_config = os.path.join(local_dir, "config_extractor.yaml")
# args.tsr_config = os.path.join(local_dir, "config_tsr.yaml")
args.vae_config = os.path.join(local_dir, "config_compressor.json")
args.autoencoder_path = os.path.join(local_dir, "compressor.ckpt")
args.tse_ckpt = os.path.join(local_dir, "extractor.pt")
# args.tsr_ckpt = os.path.join(local_dir, "tsr.pt")
args.geco_ckpt = os.path.join(local_dir, "corrector.ckpt")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
# load config
print("Loading models...")
with open(args.tse_config, 'r') as fp:
    args.tse_config = yaml.safe_load(fp)
# with open(args.tsr_config, 'r') as fp:
# args.tsr_config = yaml.safe_load(fp)
args.v_prediction = args.tse_config["ddim"]["v_prediction"]
# load compressor
autoencoder = Autoencoder(args.autoencoder_path, args.vae_config, 'stft_vae', quantization_first=True)
autoencoder.eval()
autoencoder.to(device)
# load extractor
tse_model = SoloSpeech_TSE(
    args.tse_config['diffwrap']['UDiT'],
    args.tse_config['diffwrap']['ViT'],
).to(device)
tse_model.load_state_dict(torch.load(args.tse_ckpt, map_location=device)['model'])
tse_model.eval()
# # load tsr model
# tsr_model = SoloSpeech_TSR(
# args.tsr_config['diffwrap']['UDiT']
# ).to(device)
# tsr_model.load_state_dict(torch.load(args.tsr_ckpt)['model'])
# tsr_model.eval()
# load corrector
geco_model = ScoreModel.load_from_checkpoint(
    args.geco_ckpt,
    batch_size=1, num_workers=0, kwargs=dict(gpu=False)
)
geco_model.eval(no_ema=False)
geco_model.to(device)
# load speaker identification (ECAPA-TDNN) model used to rank extractor candidates
ecapatdnn_model = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")
# cosine_sim = torch.nn.CosineSimilarity(dim=-1)
# load diffusion tools
noise_scheduler = DDIMScheduler(**args.tse_config["ddim"]['diffusers'])
# these steps reset dtype of noise_scheduler params
latents = torch.randn((1, 128, 128), device=device)
noise = torch.randn(latents.shape).to(device)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                          (noise.shape[0],),
                          device=latents.device).long()
_ = noise_scheduler.add_noise(latents, noise, timesteps)
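# Note: the add_noise() call above runs once at startup purely for its side effect on the
# scheduler's internal tensors (see the dtype comment); its output is discarded.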
@spaces.GPU
def sample_diffusion(tse_model, autoencoder, std, scheduler, device,
                     mixture=None, reference=None, lengths=None, reference_lengths=None,
                     ddim_steps=50, eta=0, seed=2025
                     ):
    with torch.no_grad():
        generator = torch.Generator(device=device).manual_seed(seed)
        scheduler.set_timesteps(ddim_steps)
        tse_pred = torch.randn(mixture.shape, generator=generator, device=device)
        # tsr_pred = torch.randn(mixture.shape, generator=generator, device=device)

        # DDIM sampling loop: iteratively denoise the latent, conditioned on the
        # compressed mixture and the enrollment reference
        for t in scheduler.timesteps:
            tse_pred = scheduler.scale_model_input(tse_pred, t)
            model_output, _ = tse_model(
                x=tse_pred,
                timesteps=t,
                mixture=mixture,
                reference=reference,
                x_len=lengths,
                ref_len=reference_lengths
            )
            tse_pred = scheduler.step(model_output=model_output, timestep=t, sample=tse_pred,
                                      eta=eta, generator=generator).prev_sample

        # for t in scheduler.timesteps:
        #     tsr_pred = scheduler.scale_model_input(tsr_pred, t)
        #     model_output, _ = tsr_model(
        #         x=tsr_pred,
        #         timesteps=t,
        #         mixture=mixture,
        #         reference=tse_pred,
        #         x_len=lengths,
        #     )
        #     tsr_pred = scheduler.step(model_output=model_output, timestep=t, sample=tsr_pred,
        #                               eta=eta, generator=generator).prev_sample

        # decode the denoised latents back to waveforms with the compressor
        tse_pred = autoencoder(embedding=tse_pred.transpose(2, 1), std=std).squeeze(1)
        # tsr_pred = autoencoder(embedding=tsr_pred.transpose(2, 1), std=std).squeeze(1)
    return tse_pred
@spaces.GPU
def tse(test_wav, enroll_wav):
print("Start Extraction...")
start_time = time.time()
mixture, _ = librosa.load(test_wav, sr=16000)
reference, _ = librosa.load(enroll_wav, sr=16000)
reference_wav = reference
reference = torch.tensor(reference).unsqueeze(0).to(device)
with torch.no_grad():
# compressor
reference, _ = autoencoder(audio=reference.unsqueeze(1))
reference_lengths = torch.LongTensor([reference.shape[-1]] * args.num_candidates).to(device)
mixture_input = torch.tensor(mixture).unsqueeze(0).to(device)
mixture_wav = mixture_input
mixture_input, std = autoencoder(audio=mixture_input.unsqueeze(1))
lengths = torch.LongTensor([mixture_input.shape[-1]] * args.num_candidates).to(device)
# extractor
mixture_input = mixture_input.repeat(args.num_candidates, 1, 1)
reference = reference.repeat(args.num_candidates, 1, 1)
tse_pred = sample_diffusion(tse_model, autoencoder, std, noise_scheduler, device, mixture_input.transpose(2,1), reference.transpose(2,1), lengths, reference_lengths, ddim_steps=args.num_infer_steps, eta=args.eta, seed=args.random_seed)
tse_pred = sample_diffusion(tse_model, autoencoder, std, noise_scheduler, device, mixture_input.transpose(2,1), reference.transpose(2,1), lengths, reference_lengths, ddim_steps=args.num_infer_steps, eta=args.eta, seed=args.random_seed)
ecapatdnn_embedding_pred = ecapatdnn_model.encode_batch(tse_pred).squeeze()
ecapatdnn_embedding_ref = ecapatdnn_model.encode_batch(torch.tensor(reference_wav)).squeeze()
cos_sims = F.cosine_similarity(ecapatdnn_embedding_pred, ecapatdnn_embedding_ref.unsqueeze(0), dim=1)
_, max_idx = torch.max(cos_sims, dim=0)
pred = tse_pred[max_idx].unsqueeze(0)
# corrector
min_leng = min(pred.shape[-1], mixture_wav.shape[-1])
x = pred[...,:min_leng]
m = mixture_wav[...,:min_leng]
norm_factor = m.abs().max()
x = x / norm_factor
m = m / norm_factor
X = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(x.cuda())), 0)
X = pad_spec(X)
M = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(m.cuda())), 0)
M = pad_spec(M)
timesteps = torch.linspace(0.5, 0.03, 1, device=M.device)
std = geco_model.sde._std(0.5*torch.ones((M.shape[0],), device=M.device))
z = torch.randn_like(M)
X_t = M + z * std[:, None, None, None]
for idx in range(len(timesteps)):
t = timesteps[idx]
if idx != len(timesteps) - 1:
dt = t - timesteps[idx+1]
else:
dt = timesteps[-1]
with torch.no_grad():
f, g = geco_model.sde.sde(X_t, t, M)
vec_t = torch.ones(M.shape[0], device=M.device) * t
mean_x_tm1 = X_t - (f - g**2*geco_model.forward(X_t, vec_t, M, X, vec_t[:,None,None,None]))*dt
if idx == len(timesteps) - 1:
X_t = mean_x_tm1
break
z = torch.randn_like(X)
X_t = mean_x_tm1 + z*g*torch.sqrt(dt)
sample = X_t
sample = sample.squeeze()
x_hat = geco_model.to_audio(sample.squeeze(), min_leng)
x_hat = x_hat * norm_factor / x_hat.abs().max()
x_hat = x_hat.detach().cpu().squeeze().numpy()
end_time = time.time()
audio_len = x_hat.shape[-1] / 16000
rtf = (end_time-start_time)/audio_len
print(f"RTF: {rtf:.4f}")
return (16000, x_hat)
@spaces.GPU
def process_audio(test_wav, enroll_wav):
    result = tse(test_wav, enroll_wav)
    return result
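# Example of calling the pipeline programmatically (outside the UI), using the bundled
# demo files listed below:
#   sr, extracted = process_audio("examples/test1.wav", "examples/test1_enroll.wav")
#   # sr == 16000; `extracted` is a 1-D numpy array containing only the target speaker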
# List of demo audio files
demo_audio_files = [
("Demo1: Extract male speaker from a mixture of multiple male speakers", "examples/test1.wav", "examples/test1_enroll.wav"),
("Demo2: Extract female speaker from a mixture of multiple female speakers", "examples/test2.wav", "examples/test2_enroll.wav"),
("Demo3: Extract male rapper from music with complex vocals", "examples/test_3_mixture.mp3", "examples/test_3_speaker.mp3"),
]
def update_audio_input(choice):
    return choice
# CSS styling (optional)
css = """
#col-container {
margin: 0 auto;
max-width: 1280px;
}
"""
# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # SoloSpeech: A Precise and High-Fidelity Target Speech Extractor
        👋 Introduction: Extract the target speaker's voice from a speech mixture, given a short enrollment recording of that speaker.
        💡 To extract sound effects or music from audio, try using [SoloAudio](https://huggingface.co/spaces/OpenSound/SoloAudio).
        🔗 Learn more about this project on the [🎯SoloSpeech Repo](https://github.com/WangHelin1997/SoloSpeech/).
        """)

        with gr.Tab("Target Speech Extraction"):
            with gr.Row():
                mixture_input = gr.Audio(label="Upload Mixture Audio",
                                         type="filepath",
                                         value="examples/test1.wav")
            # gr.Markdown("**Note:** Upload a short clip with only the target speaker. Some non-speech noise is fine.")
            with gr.Row(equal_height=True):
                enroll_input = gr.Audio(label="Upload Enrollment/Speaker Audio",
                                        type="filepath",
                                        value="examples/test1_enroll.wav",
                                        )
            with gr.Row():
                extract_button = gr.Button("Extract", variant="primary")
                # extract_button = gr.Button("Extract", scale=1)
            with gr.Row():
                result = gr.Audio(label="Extracted Speech", type="numpy")
            with gr.Row(equal_height=True):
                demo_selector = gr.Dropdown(
                    label="Select Test Demo",
                    choices=[name for name, _, _ in demo_audio_files],
                    value="Demo1: Extract male speaker from a mixture of multiple male speakers"
                )

            # Update audio inputs when selecting from dropdown
            def update_audio_inputs(choice):
                for name, mixture_path, enroll_path in demo_audio_files:
                    if name == choice:
                        return mixture_path, enroll_path
                return None, None

            demo_selector.change(
                fn=update_audio_inputs,
                inputs=demo_selector,
                outputs=[mixture_input, enroll_input]
            )

            extract_button.click(
                fn=process_audio,
                inputs=[mixture_input, enroll_input],
                outputs=[result]
            )
# Launch the Gradio demo
demo.launch()