import re
import sys
import os
import tempfile
import textwrap
from collections import OrderedDict

import numpy as np
import torch
import yaml
import librosa
import audresample
import phonemizer
import nltk
from cached_path import cached_path
from nltk.tokenize import word_tokenize

from models import ProsodyPredictor, TextEncoder, StyleEncoder, MelSpec
from Modules.vits.models import VitsModel, VitsTokenizer
from Modules.hifigan import Decoder
from Utils.PLBERT.util import load_plbert
from Utils.text_utils import transliterate_number

# word_tokenize() needs the NLTK punkt models; fetch them into the working
# directory and make that directory visible to nltk.data.
nltk.download('punkt', download_dir='./')
nltk.download('punkt_tab', download_dir='./')
nltk.data.path.append('.')
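
# English text is synthesised with the StyleTTS2 stack set up below (PL-BERT,
# prosody predictor, HiFi-GAN decoder, style encoders); foreign() covers other
# languages via Meta's MMS VITS checkpoints (facebook/mms-tts-*).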

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# StyleTTS2 symbol inventory: pad token, punctuation, Latin letters and IPA
# phonemes. A symbol's token id is its position in `symbols`.
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)

dicts = {symbol: i for i, symbol in enumerate(symbols)}


class TextCleaner:

    def __init__(self, dummy=None):
        self.word_index_dictionary = dicts
        print(len(dicts))

    def __call__(self, text):
        indexes = []
        for char in text:
            try:
                indexes.append(self.word_index_dictionary[char])
            except KeyError:
                # characters missing from the symbol set are skipped
                print('CLEAN', text)
        return indexes


textcleaner = TextCleaner()
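
# e.g. textcleaner('ˈhɛloʊ') maps each character to its id in `symbols`;
# anything outside the set is dropped and logged with a 'CLEAN' message.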


def alpha_num(f):
    f = re.sub(' +', ' ', f)              # collapse runs of spaces
    f = re.sub(r'[^A-Za-z0-9 ]+', '', f)  # keep only letters, digits and spaces
    return f


mel_spec = MelSpec().to(device)
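
# compute_style() turns a reference recording into the style vector used at
# inference time: the first 128 dims come from style_encoder (acoustic style),
# the last 128 from predictor_encoder (prosodic style).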


def compute_style(path):
    x, sr = librosa.load(path, sr=24000)
    x, _ = librosa.effects.trim(x, top_db=30)
    if sr != 24000:
        # librosa.load() above already resamples to 24 kHz; kept as a safeguard
        x = librosa.resample(x, orig_sr=sr, target_sr=24000)

    with torch.no_grad():
        x = torch.from_numpy(x[None, :]).to(device=device, dtype=torch.float)
        mel_tensor = (torch.log(1e-5 + mel_spec(x)) + 4) / 4
        ref_s = style_encoder(mel_tensor)
        ref_p = predictor_encoder(mel_tensor)
        s = torch.cat([ref_s, ref_p], dim=3)
        s = s[:, :, 0, :].transpose(1, 2)
    return s


global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us', preserve_punctuation=True, with_stress=True)
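
# The espeak backend relies on the espeak-ng library being installed on the
# system; phonemizer only wraps it.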

with open('Utils/config.yml') as f:
    args = yaml.safe_load(f)
ASR_config = args['ASR_config']

bert = load_plbert(args['PLBERT_dir']).eval().to(device)


decoder = Decoder(dim_in=512,
                  style_dim=128,
                  dim_out=80,
                  resblock_kernel_sizes=[3, 7, 11],
                  upsample_rates=[10, 5, 3, 2],
                  upsample_initial_channel=512,
                  resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                  upsample_kernel_sizes=[20, 10, 6, 4]).eval().to(device)

text_encoder = TextEncoder(channels=512,
                           kernel_size=5,
                           depth=3,
                           n_symbols=178).eval().to(device)

predictor = ProsodyPredictor(style_dim=128,
                             d_hid=512,
                             nlayers=3,
                             max_dur=50).eval().to(device)

style_encoder = StyleEncoder(dim_in=64,
                             style_dim=128,
                             max_conv_dim=512).eval().to(device)

predictor_encoder = StyleEncoder(dim_in=64,
                                 style_dim=128,
                                 max_conv_dim=512).eval().to(device)

bert_encoder = torch.nn.Linear(bert.config.hidden_size, 512).eval().to(device)
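
# These hyperparameters have to line up with the LibriTTS checkpoint loaded
# below, since every load_state_dict() call uses strict=True.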

params_whole = torch.load(str(cached_path(
    "hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")),
    map_location='cpu', weights_only=True)
params = params_whole['net']
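
# The checkpoint stores one state dict per sub-module, with keys prefixed by
# 'module.' (presumably saved from DataParallel-wrapped modules); strip that
# prefix before loading.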


def _del_prefix(d):
    # drop the leading 'module.' (7 characters) from every key
    out = OrderedDict()
    for k, v in d.items():
        out[k[7:]] = v
    return out


bert.load_state_dict(_del_prefix(params['bert']), strict=True)
bert_encoder.load_state_dict(_del_prefix(params['bert_encoder']), strict=True)
predictor.load_state_dict(_del_prefix(params['predictor']), strict=True)
decoder.load_state_dict(_del_prefix(params['decoder']), strict=True)
text_encoder.load_state_dict(_del_prefix(params['text_encoder']), strict=True)
predictor_encoder.load_state_dict(_del_prefix(params['predictor_encoder']), strict=True)
style_encoder.load_state_dict(_del_prefix(params['style_encoder']), strict=True)


def inference(text, ref_s):
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)
    tokens = textcleaner(ps)
    tokens.insert(0, 0)  # prepend the pad token (id 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        hidden_states = text_encoder(tokens)
        bert_dur = bert(tokens, attention_mask=torch.ones_like(tokens))
        d_en = bert_encoder(bert_dur).transpose(-1, -2)

        # alignment, F0 and energy predicted from the prosodic half of the style
        aln_trg, F0_pred, N_pred = predictor(d_en=d_en, s=ref_s[:, 128:, :])

        asr = torch.bmm(aln_trg, hidden_states)
        asr = asr.transpose(1, 2)
        asr = torch.cat([asr[:, :, 0:1], asr[:, :, 0:-1]], 2)

        x = decoder(asr=asr,
                    F0_curve=F0_pred,
                    N=N_pred,
                    s=ref_s[:, :128, :])

    x = x.cpu().numpy()[0, 0, :]
    x[-400:] = 0  # zero out the final samples

    if x.shape[0] > 10:
        # StyleTTS2 renders at 24 kHz; downsample to 16 kHz for the caller
        x = audresample.resample(signal=x.astype(np.float32),
                                 original_rate=24000,
                                 target_rate=16000)[0, :]
    else:
        print('EMPTY TTS', x.shape)
        x = np.zeros(0)
    return x
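
# Typical use (the reference wav path is illustrative):
#     ref = compute_style('reference_speaker.wav')   # any speech clip, loaded at 24 kHz
#     wav16 = inference('Hello there.', ref)         # float32 numpy audio at 16 kHz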


TTS_LANGUAGES = {}
with open("Utils/all_langs.csv") as f:
    for line in f:
        iso, name = line.split(",", 1)
        TTS_LANGUAGES[iso.strip()] = name.strip()
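
# all_langs.csv is expected to hold one 'code,name' pair per line, mapping a
# language code to its display name.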

# Crude respellings applied before the Romani-Latin MMS voice (see foreign());
# they nudge the pronunciation of a few letters and loanwords.
PHONEME_MAP = {
    'služ': 'sloooozz',
    'suver': 'siuveeerra',
    'država': 'dirrezav',
    'iči': 'ici',
    's ': 'se',
    'q': 'ku',
    'w': 'aou',
    'z': 's',
    "š": "s",
    'th': 'ta',
    'v': 'vv',
    "ž": "z",
}


def fix_phones(text):
    for src, target in PHONEME_MAP.items():
        text = text.replace(src, target)
    # swap sentence punctuation for '_ _' tokens
    return text.replace(',', '_ _').replace('.', '_ _')


def has_cyrillic(text):
    return bool(re.search('[\u0400-\u04FF]', text))


def foreign(text=None, lang='romanian', speed=None):
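    # Map a free-form language hint ('Romanian', 'german', ...) onto an MMS-TTS
    # code. Serbian/Bosnian/Montenegrin/Macedonian requests are routed to the
    # Romani (rmc) voices, presumably because no dedicated MMS model is wired
    # up for them here.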
    lang = lang.lower()

    if 'hun' in lang:
        lang_code = 'hun'
    elif any(i in lang for i in ['ser', 'bosn', 'herzegov', 'montenegr', 'macedon']):
        if has_cyrillic(text):
            lang_code = 'rmc-script_cyrillic'
        else:
            lang_code = 'rmc-script_latin'
    elif 'rom' in lang:
        lang_code = 'ron'
    elif 'ger' in lang or 'deu' in lang or 'allem' in lang:
        lang_code = 'deu'
    elif 'alban' in lang:
        lang_code = 'sqi'
    else:
        lang_code = lang.split()[0].strip()

    # Reload the MMS checkpoint only when the requested language changes.
    global cached_lang_code, cached_net_g, cached_tokenizer
    if 'cached_lang_code' not in globals() or cached_lang_code != lang_code:
        cached_lang_code = lang_code
        cached_net_g = VitsModel.from_pretrained(
            f'facebook/mms-tts-{lang_code}').eval().to(device)
        cached_tokenizer = VitsTokenizer.from_pretrained(f'facebook/mms-tts-{lang_code}')

    net_g = cached_net_g
    tokenizer = cached_tokenizer

    total_audio = []

    # Split long strings into ~440-character chunks so each VITS call stays short.
    if not isinstance(text, list):
        text = [sub_sent + ' ' for sub_sent in textwrap.wrap(text, 440, break_long_words=False)]

    for _t in text:
        _t = _t.lower()

        # spell out digits in the target language where supported
        try:
            _t = transliterate_number(_t, lang=lang_code)
        except NotImplementedError:
            print(f'Transliterate Numbers - NotImplemented for {lang_code=}', _t,
                  '\n____________________________________________')

        if lang_code == 'rmc-script_latin':
            _t = fix_phones(_t)
        elif lang_code == 'ron':
            # normalise cedilla variants and respell the remaining Romanian
            # diacritics with plain ASCII
            _t = (_t.replace("ţ", "ț").replace('ț', 'ts').replace('î', 'u')
                  .replace('â', 'a').replace('ş', 's'))

        inputs = tokenizer(_t, return_tensors="pt")

        with torch.no_grad():
            # this VitsModel variant (Modules.vits) accepts a lang_code argument
            x = net_g(input_ids=inputs.input_ids.to(device),
                      attention_mask=inputs.attention_mask.to(device),
                      lang_code=lang_code,
                      )[0, :]

        total_audio.append(x)
        print(f'\n\n_______________________________ {_t} {x.shape=}')

    x = torch.cat(total_audio).cpu().numpy()
    return x
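
# Sketch of typical use (language strings are free-form hints, texts illustrative):
#     wav = foreign(text='Bună ziua, ce mai faceți?', lang='romanian')
#     wav = foreign(text='Szia, hogy vagy?', lang='hungarian')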