|
import re
|
|
import yaml
|
|
from munch import Munch
|
|
import numpy as np
|
|
import librosa
|
|
import noisereduce as nr
|
|
from meldataset import TextCleaner
|
|
import torch
|
|
import torchaudio
|
|
from nltk.tokenize import word_tokenize
|
|
import nltk
|
|
# Import-time side effect: ask NLTK to download the 'punkt_tab' resource.
# Presumably this is the tokenizer data needed by `word_tokenize` (imported above) - TODO confirm.
nltk.download('punkt_tab')
|
|
|
|
from models import ProsodyPredictor, TextEncoder, StyleEncoder
|
|
from Modules.hifigan import Decoder
|
|
|
|
class Preprocess:
    """Text and waveform preprocessing helpers for the synthesis pipeline."""

    def __text_normalize(self, text):
        """Replace each listed punctuation mark with "." and collapse whitespace.

        Args:
            text: Input string.

        Returns:
            The string with every character from the punctuation list turned
            into ".", runs of whitespace reduced to one space, and leading /
            trailing whitespace removed.
        """
        punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
        map_to = "."
        punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")

        # Every listed mark becomes a sentence boundary for the later split(".").
        text = punctuation_pattern.sub(map_to, text)

        # Collapse consecutive whitespace and trim both ends.
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def __merge_fragments(self, texts, n):
        """Greedily join neighbouring fragments until each has at least ``n`` words.

        Fragments are joined with ", ". If the final fragment is still shorter
        than ``n`` words and a previous fragment exists, it is appended to that
        previous fragment instead of being kept on its own.

        Args:
            texts: List of text fragments (may be empty).
            n: Minimum number of whitespace-separated words per merged fragment.

        Returns:
            List of merged fragments; an empty list when ``texts`` is empty.
        """
        merged = []
        i = 0
        while i < len(texts):
            fragment = texts[i]
            j = i + 1
            # Keep absorbing following fragments while this one is too short.
            while len(fragment.split()) < n and j < len(texts):
                fragment += ", " + texts[j]
                j += 1
            merged.append(fragment)
            i = j

        # Only the last fragment can still be short (the input ran out);
        # checking the length first also makes empty input safe.
        if len(merged) > 1 and len(merged[-1].split()) < n:
            tail = merged.pop()
            merged[-1] = merged[-1] + ", " + tail
        return merged

    def wave_preprocess(self, wave):
        """Convert a waveform into a normalised log-mel tensor.

        Args:
            wave: NumPy array holding the waveform (it is passed to
                ``torch.from_numpy``). Presumably mono audio at the sample rate
                the model was trained with - TODO confirm.

        Returns:
            ``(log(1e-5 + mel) - mean) / std`` with ``mean=-4`` and ``std=4``,
            with one extra leading dimension added by ``unsqueeze(0)``.
        """
        to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
        mean, std = -4, 4
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = to_mel(wave_tensor)
        # The 1e-5 offset keeps log() finite for zero-energy bins.
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
        return mel_tensor

    def text_preprocess(self, text, n_merge=12):
        """Split text into fragments of at least ``n_merge`` words where possible.

        The text is normalised, split on ".", stripped, emptied of blank
        pieces, and then short pieces are merged with their neighbours.

        Args:
            text: Input string.
            n_merge: Minimum word count targeted for each returned fragment.

        Returns:
            List of fragments; empty when the text contains nothing but
            punctuation and whitespace.
        """
        pieces = self.__text_normalize(text).split(".")
        pieces = [piece.strip() for piece in pieces]
        pieces = [piece for piece in pieces if piece != '']
        return self.__merge_fragments(pieces, n=n_merge)

    def length_to_mask(self, lengths):
        """Build a boolean mask that is True at positions beyond each length.

        Args:
            lengths: 1-D integer tensor of sequence lengths.

        Returns:
            Tensor of shape ``(len(lengths), lengths.max())`` where entry
            ``[b, p]`` is True when ``p + 1 > lengths[b]``.
        """
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask + 1, lengths.unsqueeze(1))
        return mask
|
|
|
|
|
|
class StyleTTS2(torch.nn.Module):
    """Inference wrapper bundling the text encoder, prosody predictor,
    style encoder and decoder, plus helpers to compute/save/load a reference
    style and to synthesise audio from text."""

    def __init__(self, config_path, models_path):
        """Build the sub-modules from a YAML config and load their weights.

        Args:
            config_path: Path to the YAML configuration file. It must provide
                ``symbol`` (with ``pad``, ``punctuation``, ``letters``,
                ``letters_ipa``, ``extend``) and ``model_params``.
            models_path: Path to the checkpoint passed to ``torch.load``.

        Raises:
            SystemExit: If the symbol section of the config is missing or
                incomplete.
        """
        super().__init__()
        # Empty buffer used only to read back the device the module lives on.
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()
        self.ref_s = None

        # Context manager so the config file handle is closed deterministically.
        with open(config_path, "r", encoding="utf-8") as config_file:
            config = yaml.safe_load(config_file)

        try:
            symbols = (
                list(config['symbol']['pad']) +
                list(config['symbol']['punctuation']) +
                list(config['symbol']['letters']) +
                list(config['symbol']['letters_ipa']) +
                list(config['symbol']['extend'])
            )
            # Later duplicates overwrite earlier ones, as with index assignment.
            symbol_dict = {symbol: index for index, symbol in enumerate(symbols)}

            n_token = len(symbol_dict) + 1
            print("\nFound:", n_token, "symbols")
        except (KeyError, TypeError) as e:
            # KeyError: a section is absent; TypeError: a section is null/not iterable.
            print(f"\nERROR: Cannot find {e} in config file!\nYour config file is likely outdated, please download updated version from the repository.")
            raise SystemExit(1) from e

        args = self.__recursive_munch(config['model_params'])
        args['n_token'] = n_token

        self.cleaner = TextCleaner(symbol_dict, debug=False)

        # NOTE: `assert` is skipped under `python -O`; kept to preserve the raised exception type.
        assert args.decoder.type in ['hifigan'], 'Decoder type unknown'

        self.decoder = Decoder(
            dim_in=args.hidden_dim,
            style_dim=args.style_dim,
            dim_out=args.n_mels,
            resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
            upsample_rates=args.decoder.upsample_rates,
            upsample_initial_channel=args.decoder.upsample_initial_channel,
            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
        )
        self.predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
        self.text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
        self.style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)

        self.__load_models(models_path)

    def __recursive_munch(self, d):
        """Recursively convert nested dicts (also inside lists) into ``Munch`` objects.

        Args:
            d: A dict, a list, or any other value.

        Returns:
            A ``Munch`` for dicts, a list of converted items for lists, and
            the value itself otherwise.
        """
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        if isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        return d

    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        """Pull values whose |z-score| exceeds ``threshold`` back toward the mean.

        Outliers are replaced by ``mean ± threshold * std * factor`` (sign
        follows the side of the mean the value lies on). The input tensor is
        not modified.

        Args:
            tensor: Floating-point tensor; statistics are taken over all elements.
            threshold: Absolute z-score above which a value counts as an outlier.
            factor: Scale applied to the clamp distance ``threshold * std``.

        Returns:
            A new tensor with the same shape as ``tensor``.
        """
        mean = tensor.mean()
        std = tensor.std()
        # With zero or undefined std the z-scores are NaN; NaN comparisons are
        # False, so nothing is flagged and the input is returned unchanged.
        z = (tensor - mean) / std

        outlier_mask = torch.abs(z) > threshold

        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)

        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]
        return result

    def __load_models(self, models_path):
        """Load sub-module weights from a checkpoint and print parameter counts.

        The checkpoint is expected to hold a ``'net'`` mapping from sub-module
        name to state dict. If a strict load fails, keys are retried with a
        leading ``"module."`` prefix removed and ``strict=False``.

        Args:
            models_path: Path to the checkpoint file.
        """
        module_params = []
        model = {'decoder': self.decoder, 'predictor': self.predictor, 'text_encoder': self.text_encoder, 'style_encoder': self.style_encoder}

        # NOTE(security): torch.load unpickles data; only load checkpoints from trusted sources.
        params_whole = torch.load(models_path, map_location='cpu')
        params = params_whole['net']
        params = {key: value for key, value in params.items() if key in model}

        prefix = 'module.'
        for key, module in model.items():
            try:
                module.load_state_dict(params[key])
            except RuntimeError:
                # load_state_dict raises RuntimeError on key/shape mismatch.
                # Presumably the keys carry a 'module.' prefix from a wrapped
                # (e.g. data-parallel) model; strip it only where present.
                stripped = {
                    (name[len(prefix):] if name.startswith(prefix) else name): value
                    for name, value in params[key].items()
                }
                module.load_state_dict(stripped, strict=False)

            total_params = sum(p.numel() for p in module.parameters())
            print(key, ":", total_params)
            module_params.append(total_params)

        print('\nTotal', ":", sum(module_params))

    def __compute_style(self, path, denoise, split_dur):
        """Compute a style tensor from a reference audio file.

        The audio is loaded at 24 kHz, trimmed (``top_db=30``), capped at
        20 seconds, optionally blended with a noise-reduced copy, and encoded
        with the style encoder. When ``split_dur`` is positive and the audio
        is at least 4 seconds long, the style is the average over consecutive
        ``split_dur``-second chunks (a final chunk shorter than 1 second is
        dropped).

        Args:
            path: Path to the reference audio file.
            denoise: Blend weight of the denoised signal; capped at 1, values
                <= 0 disable denoising.
            split_dur: Chunk length in seconds; 0 disables chunking, other
                values are truncated to int and raised to at least 1.

        Returns:
            The style encoder output (averaged over chunks when chunking).
        """
        device = self.get_device.device
        denoise = min(denoise, 1)
        if split_dur != 0:
            split_dur = max(int(split_dur), 1)
        max_samples = 24000 * 20
        print("Computing the style for:", path)

        # librosa.load resamples to the requested rate, so `sr` is always
        # 24000 here and no separate resampling step is needed.
        wave, sr = librosa.load(path, sr=24000)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if len(audio) > max_samples:
            audio = audio[:max_samples]

        if denoise > 0.0:
            audio_denoise = nr.reduce_noise(y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300)
            audio = audio * (1 - denoise) + audio_denoise * denoise

        with torch.no_grad():
            if split_dur > 0 and len(audio) / sr >= 4:
                jump = sr * split_dur
                total_len = len(audio)

                # The first chunk is always used.
                mel_tensor = self.preprocess.wave_preprocess(audio[0:jump]).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))
                count = 1

                for start in range(jump, total_len, jump):
                    is_last = start + jump >= total_len
                    # Skip a trailing remainder shorter than one second.
                    if is_last and (total_len - start) / sr < 1:
                        continue
                    # Slicing clamps at the end, so this also covers the tail chunk.
                    mel_tensor = self.preprocess.wave_preprocess(audio[start:start + jump]).to(device)
                    ref_s += self.style_encoder(mel_tensor.unsqueeze(1))
                    count += 1

                ref_s /= count
            else:
                mel_tensor = self.preprocess.wave_preprocess(audio).to(device)
                ref_s = self.style_encoder(mel_tensor.unsqueeze(1))

        return ref_s

    def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
        """Synthesise audio for one text fragment.

        Args:
            phonem: Text fragment; it is tokenised with ``word_tokenize`` and
                mapped to ids by ``self.cleaner``.
            ref_s: Style tensor (moved to the model device before use).
            speed: Speaking-rate divisor, clamped to ``[0.0001, 2]``.
            prev_d_mean: Mean duration returned by the previous call, or 0 to
                use this fragment's own mean.
            t: Blend weight between predicted durations and random durations
                drawn around the target mean.

        Returns:
            Tuple ``(waveform, duration_mean)``: the decoder output as a
            squeezed NumPy array and the mean of the final durations (tensor).
        """
        device = self.get_device.device
        speed = min(max(speed, 0.0001), 2)

        phonem = ' '.join(word_tokenize(phonem))
        tokens = self.cleaner(phonem)
        # Surround the sequence with id 0 on both sides.
        tokens.insert(0, 0)
        tokens.append(0)
        tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

        with torch.no_grad():
            n_tokens = tokens.shape[-1]
            input_lengths = torch.LongTensor([n_tokens]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)

            t_en = self.text_encoder(tokens, input_lengths, text_mask)
            s = ref_s.to(device)

            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)

            # Blend predicted durations with noise centred on the previous
            # fragment's mean (or this fragment's own mean on the first call).
            target_mean = prev_d_mean if prev_d_mean != 0 else duration.mean()
            dur_stats = torch.empty(duration.shape).normal_(mean=float(target_mean), std=float(duration.std())).to(device)
            duration = duration * (1 - t) + dur_stats * t
            # Outlier clamping skips the first position and the last two.
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])

            duration /= speed

            pred_dur = torch.round(duration.squeeze()).clamp(min=1)
            # One row per token, one column per output frame.
            pred_aln_trg = torch.zeros(n_tokens, int(pred_dur.sum().data))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                frames = int(pred_dur[i].data)
                pred_aln_trg[i, c_frame:c_frame + frames] = 1
                c_frame += frames
            alignment = pred_aln_trg.unsqueeze(0).to(device)

            en = (d.transpose(-1, -2) @ alignment)
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = (t_en @ alignment)

            out = self.decoder(asr, F0_pred, N_pred, s)

        # NOTE(review): this mean is taken after dividing by `speed`, while the
        # next call compares it with durations before that division - confirm
        # this is intended when speed != 1.
        return out.squeeze().cpu().numpy(), duration.mean()

    def get_styles(self, speaker, denoise=0.3, avg_style=True, load_styles=False):
        """Return the style description used by ``generate``.

        Args:
            speaker: Mapping with ``'path'`` (reference audio) and ``'speed'``.
            denoise: Denoise blend weight forwarded to the style computation.
            avg_style: If True, average the style over 3-second chunks.
            load_styles: If True, reuse the style already stored on the
                instance instead of computing it from ``speaker['path']``.

        Returns:
            Dict with keys ``'style'``, ``'path'`` and ``'speed'``.

        Raises:
            RuntimeError: If ``load_styles`` is True but no style is stored.
        """
        if not load_styles:
            split_dur = 3 if avg_style else 0
            self.ref_s = self.__compute_style(speaker['path'], denoise=denoise, split_dur=split_dur)
        elif self.ref_s is None:
            raise RuntimeError("Have to compute or load the styles first!")

        return {
            'style': self.ref_s,
            'path': speaker['path'],
            'speed': speaker['speed'],
        }

    def save_styles(self, save_dir):
        """Save the stored style tensor with ``torch.save``.

        Args:
            save_dir: Destination passed to ``torch.save``.

        Raises:
            RuntimeError: If no style has been computed or loaded yet.
        """
        if self.ref_s is None:
            raise RuntimeError("Have to compute the styles before saving it.")
        torch.save(self.ref_s, save_dir)
        print("Saved styles!")

    def load_styles(self, save_dir):
        """Load a previously saved style tensor into ``self.ref_s``.

        Failures are reported by printing the exception; ``self.ref_s`` is
        left unchanged in that case.

        Args:
            save_dir: Path passed to ``torch.load``.
        """
        try:
            # Map onto the model's device so a style saved on GPU can still be
            # loaded on a CPU-only machine.
            self.ref_s = torch.load(save_dir, map_location=self.get_device.device)
            print("Loaded styles!")
        except Exception as e:
            # Deliberately best-effort: report and continue.
            print(e)

    def generate(self, phonem, style, stabilize=True, n_merge=16):
        """Synthesise audio for a whole text.

        The text is split into fragments, each fragment is synthesised
        separately, 4000 samples are cut from both ends of every fragment,
        and the pieces are concatenated with 4000 zero samples of padding on
        each side.

        Args:
            phonem: Input text.
            style: Dict with ``'style'`` and ``'speed'`` (as returned by
                ``get_styles``).
            stabilize: If True, blend durations with weight 0.2 toward the
                previous fragment's mean; otherwise use weight 0.
            n_merge: Minimum word count targeted per fragment.

        Returns:
            1-D NumPy array with the generated waveform.

        Raises:
            ValueError: If the text yields no fragments to synthesise.
        """
        smooth_value = 0.2 if stabilize else 0

        list_wav = []
        prev_d_mean = 0

        print("Generating Audio...")
        text_norm = self.preprocess.text_preprocess(phonem, n_merge=n_merge)
        if not text_norm:
            # np.concatenate would otherwise fail with an unrelated message.
            raise ValueError("No text to synthesise after preprocessing.")

        for sentence in text_norm:
            wav, prev_d_mean = self.__inference(sentence, style['style'], speed=style['speed'], prev_d_mean=prev_d_mean, t=smooth_value)
            # Drop 4000 samples at each end of the fragment.
            wav = wav[4000:-4000]
            list_wav.append(wav)

        final_wav = np.concatenate(list_wav)
        final_wav = np.concatenate([np.zeros([4000]), final_wav, np.zeros([4000])], axis=0)
        return final_wav