Commit · 04be12f
1 Parent(s): ff769a6

disable deepspeed and cuda kernel

Files changed:
- indextts/infer_v2.py +2 -3
- indextts/s2mel/modules/.ipynb_checkpoints/audio-checkpoint.py +82 -0
- indextts/s2mel/modules/.ipynb_checkpoints/commons-checkpoint.py +610 -0
- indextts/s2mel/modules/.ipynb_checkpoints/diffusion_transformer-checkpoint.py +258 -0
- indextts/s2mel/modules/.ipynb_checkpoints/flow_matching-checkpoint.py +171 -0
- indextts/s2mel/modules/.ipynb_checkpoints/length_regulator-checkpoint.py +141 -0
- webui.py +3 -1
indextts/infer_v2.py
CHANGED
@@ -35,7 +35,7 @@ import torch.nn.functional as F
 class IndexTTS2:
     def __init__(
             self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, device=None,
-            use_cuda_kernel=None,
+            use_cuda_kernel=None,use_deepspeed=False
     ):
         """
         Args:
@@ -83,14 +83,13 @@ class IndexTTS2:
             try:
                 import deepspeed

-                use_deepspeed = True
             except (ImportError, OSError, CalledProcessError) as e:
                 use_deepspeed = False
                 print(f">> DeepSpeed加载失败,回退到标准推理: {e}")

             self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
         else:
-            self.gpt.post_init_gpt2_config(use_deepspeed=
+            self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=False)

         if self.use_cuda_kernel:
             # preload the CUDA kernel for BigVGAN
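For reference, a minimal usage sketch of the constructor after this change. The import path and keyword values are illustrative assumptions, not part of the commit (the webui.py change at the bottom shows the actual call this Space uses):

# Hedged sketch: DeepSpeed and the BigVGAN CUDA kernel both stay off unless
# explicitly requested, matching the new defaults introduced by this commit.
from indextts.infer_v2 import IndexTTS2   # assumed import path

tts = IndexTTS2(
    cfg_path="checkpoints/config.yaml",    # defaults from the signature above
    model_dir="checkpoints",
    is_fp16=False,
    use_cuda_kernel=False,
    use_deepspeed=False,                   # new keyword added in this commit
)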
indextts/s2mel/modules/.ipynb_checkpoints/audio-checkpoint.py
ADDED
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    # if torch.min(y) < -1.0:
    #     print("min value is ", torch.min(y))
    # if torch.max(y) > 1.0:
    #     print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
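A small, hedged usage sketch of the mel_spectrogram helper above. The analysis settings are common 22.05 kHz values chosen for illustration, not values read from this repository's config, and the import path is an assumption:

import torch
# from indextts.s2mel.modules.audio import mel_spectrogram  # assumed import path

# one second of dummy audio, batch of 1, values in [-1, 1]
y = torch.randn(1, 22050).clamp(-1.0, 1.0)

mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)  # (1, 80, n_frames)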
indextts/s2mel/modules/.ipynb_checkpoints/commons-checkpoint.py
ADDED
@@ -0,0 +1,610 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json
import argparse
from torch.nn.parallel import DistributedDataParallel as DDP

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments_audio(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
        dtype=torch.long
    )
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def avg_with_mask(x, mask):
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask.unsqueeze(1)

    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    return (x * mask).sum() / mask.sum()


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)
    """
    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
    return x


def load_F0_models(path):
    # load F0 model
    from .JDC.model import JDCNet

    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(path, map_location="cpu")["net"]
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model


def modify_w2v_forward(self, output_layer=15):
    """
    change forward method of w2v encoder to get its intermediate layer output
    :param self:
    :param layer:
    :return:
    """
    from transformers.modeling_outputs import BaseModelOutput

    def forward(
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        conv_attention_mask = attention_mask
        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = hidden_states.masked_fill(
                ~attention_mask.bool().unsqueeze(-1), 0.0
            )

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype
            )
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )

        hidden_states = self.dropout(hidden_states)

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)
        else:
            relative_position_embeddings = None

        deepspeed_zero3_is_enabled = False

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = (
                True
                if self.training and (dropout_probability < self.config.layerdrop)
                else False
            )
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        relative_position_embeddings,
                        output_attentions,
                        conv_attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        relative_position_embeddings=relative_position_embeddings,
                        output_attentions=output_attentions,
                        conv_attention_mask=conv_attention_mask,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

            if i == output_layer - 1:
                break

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward


MATPLOTLIB_FLAG = False


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        import logging

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def normalize_f0(f0_sequence):
    # Remove unvoiced frames (replace with -1)
    voiced_indices = np.where(f0_sequence > 0)[0]
    f0_voiced = f0_sequence[voiced_indices]

    # Convert to log scale
    log_f0 = np.log2(f0_voiced)

    # Calculate mean and standard deviation
    mean_f0 = np.mean(log_f0)
    std_f0 = np.std(log_f0)

    # Normalize the F0 sequence
    normalized_f0 = (log_f0 - mean_f0) / std_f0

    # Create the normalized F0 sequence with unvoiced frames
    normalized_sequence = np.zeros_like(f0_sequence)
    normalized_sequence[voiced_indices] = normalized_f0
    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames

    return normalized_sequence


class MyModel(nn.Module):
    def __init__(self, args):
        super(MyModel, self).__init__()
        from modules.flow_matching import CFM
        from modules.length_regulator import InterpolateRegulator

        length_regulator = InterpolateRegulator(
            channels=args.length_regulator.channels,
            sampling_ratios=args.length_regulator.sampling_ratios,
            is_discrete=args.length_regulator.is_discrete,
            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
            codebook_size=args.length_regulator.content_codebook_size,
            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
        )

        self.models = nn.ModuleDict({
            'cfm': CFM(args),
            'length_regulator': length_regulator
        })

    def forward(self, x, target_lengths, prompt_len, cond, y):
        x = self.models['cfm'](x, target_lengths, prompt_len, cond, y)
        return x

    def forward2(self, S_ori, target_lengths, F0_ori):
        x = self.models['length_regulator'](S_ori, ylens=target_lengths, f0=F0_ori)
        return x

def build_model(args, stage="DiT"):
    if stage == "DiT":
        from modules.flow_matching import CFM
        from modules.length_regulator import InterpolateRegulator

        length_regulator = InterpolateRegulator(
            channels=args.length_regulator.channels,
            sampling_ratios=args.length_regulator.sampling_ratios,
            is_discrete=args.length_regulator.is_discrete,
            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
            codebook_size=args.length_regulator.content_codebook_size,
            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
        )
        cfm = CFM(args)
        nets = Munch(
            cfm=cfm,
            length_regulator=length_regulator,
        )

    elif stage == 'codec':
        from dac.model.dac import Encoder
        from modules.quantize import (
            FAquantizer,
        )

        encoder = Encoder(
            d_model=args.DAC.encoder_dim,
            strides=args.DAC.encoder_rates,
            d_latent=1024,
            causal=args.causal,
            lstm=args.lstm,
        )

        quantizer = FAquantizer(
            in_dim=1024,
            n_p_codebooks=1,
            n_c_codebooks=args.n_c_codebooks,
            n_t_codebooks=2,
            n_r_codebooks=3,
            codebook_size=1024,
            codebook_dim=8,
            quantizer_dropout=0.5,
            causal=args.causal,
            separate_prosody_encoder=args.separate_prosody_encoder,
            timbre_norm=args.timbre_norm,
        )

        nets = Munch(
            encoder=encoder,
            quantizer=quantizer,
        )

    elif stage == "mel_vocos":
        from modules.vocos import Vocos
        decoder = Vocos(args)
        nets = Munch(
            decoder=decoder,
        )

    else:
        raise ValueError(f"Unknown stage: {stage}")

    return nets


def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
    load_ema=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    if load_ema and "ema" in state:
        print("Loading EMA")
        for key in model:
            i = 0
            for param_name in params[key]:
                if "input_pos" in param_name:
                    continue
                assert params[key][param_name].shape == state["ema"][key][0][i].shape
                params[key][param_name] = state["ema"][key][0][i].clone()
                i += 1
    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            model_state_dict = model[key].state_dict()
            # 过滤出形状匹配的键值对
            filtered_state_dict = {
                k: v
                for k, v in params[key].items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape
            }
            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
            if skipped_keys:
                print(
                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
                )
            print("%s loaded" % key)
            model[key].load_state_dict(filtered_state_dict, strict=False)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters

def load_checkpoint2(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
    load_ema=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    if load_ema and "ema" in state:
        print("Loading EMA")
        for key in model.models:
            i = 0
            for param_name in params[key]:
                if "input_pos" in param_name:
                    continue
                assert params[key][param_name].shape == state["ema"][key][0][i].shape
                params[key][param_name] = state["ema"][key][0][i].clone()
                i += 1
    for key in model.models:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            model_state_dict = model.models[key].state_dict()
            # 过滤出形状匹配的键值对
            filtered_state_dict = {
                k: v
                for k, v in params[key].items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape
            }
            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
            if skipped_keys:
                print(
                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
                )
            print("%s loaded" % key)
            model.models[key].load_state_dict(filtered_state_dict, strict=False)
    model.eval()
    # _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters

def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
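As a quick, hedged illustration of the sequence_mask helper that the other added modules import from this file (the lengths are arbitrary and the import path is assumed):

import torch
# from indextts.s2mel.modules.commons import sequence_mask  # assumed import path

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths)   # max_length defaults to lengths.max()
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])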
indextts/s2mel/modules/.ipynb_checkpoints/diffusion_transformer-checkpoint.py
ADDED
@@ -0,0 +1,258 @@
import torch
from torch import nn
import math

from modules.gpt_fast.model import ModelArgs, Transformer
# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
from modules.wavenet import WN
from modules.commons import sequence_mask

from torch.nn.utils import weight_norm

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = 10000
        self.scale = 1000

        half = frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        )
        self.register_buffer("freqs", freqs)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        args = self.scale * t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.frequency_embedding_size % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t)
        t_emb = self.mlp(t_freq)
        return t_emb


class StyleEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
        self.input_size = input_size
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        else:
            labels = self.style_in(labels)
        embeddings = labels
        return embeddings

class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class DiT(torch.nn.Module):
    def __init__(
        self,
        args
    ):
        super(DiT, self).__init__()
        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
        model_args = ModelArgs(
            block_size=16384,  # args.DiT.block_size,
            n_layer=args.DiT.depth,
            n_head=args.DiT.num_heads,
            dim=args.DiT.hidden_dim,
            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
            vocab_size=1024,
            uvit_skip_connection=self.uvit_skip_connection,
            time_as_token=self.time_as_token,
        )
        self.transformer = Transformer(model_args)
        self.in_channels = args.DiT.in_channels
        self.out_channels = args.DiT.in_channels
        self.num_heads = args.DiT.num_heads

        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))

        self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
        self.content_codebook_size = args.DiT.content_codebook_size  # for discrete content
        self.content_dim = args.DiT.content_dim  # for continuous content
        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True)  # continuous content

        self.is_causal = args.DiT.is_causal

        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)

        # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
        # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))

        input_pos = torch.arange(16384)
        self.register_buffer("input_pos", input_pos)

        self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
        if self.final_layer_type == 'wavenet':
            self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
            self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
            self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
                              kernel_size=args.wavenet.kernel_size,
                              dilation_rate=args.wavenet.dilation_rate,
                              n_layers=args.wavenet.num_layers,
                              gin_channels=args.wavenet.hidden_dim,
                              p_dropout=args.wavenet.p_dropout,
                              causal=False)
            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
            self.res_projection = nn.Linear(args.DiT.hidden_dim,
                                            args.wavenet.hidden_dim)  # residual connection from transformer output to final output
            self.wavenet_style_condition = args.wavenet.style_condition
            assert args.DiT.style_condition == args.wavenet.style_condition
        else:
            self.final_mlp = nn.Sequential(
                nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
                nn.SiLU(),
                nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
            )
        self.transformer_style_condition = args.DiT.style_condition

        self.class_dropout_prob = args.DiT.class_dropout_prob
        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)

        self.long_skip_connection = args.DiT.long_skip_connection
        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)

        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
                                             args.DiT.hidden_dim)
        if self.style_as_token:
            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)

    def setup_caches(self, max_batch_size, max_seq_length):
        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)

    def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
        """
        x (torch.Tensor): random noise
        prompt_x (torch.Tensor): reference mel + zero mel
            shape: (batch_size, 80, 795+1068)
        x_lens (torch.Tensor): mel frames output
            shape: (batch_size, mel_timesteps)
        t (torch.Tensor): random timestep
            shape: (batch_size)
        style (torch.Tensor): reference global style
            shape: (batch_size, 192)
        cond (torch.Tensor): semantic info of reference audio and altered audio
            shape: (batch_size, mel_timesteps(795+1069), 512)
        """
        class_dropout = False
        if self.training and torch.rand(1) < self.class_dropout_prob:
            class_dropout = True
        if not self.training and mask_content:
            class_dropout = True
        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
        cond_in_module = self.cond_projection

        B, _, T = x.size()

        t1 = self.t_embedder(t)  # (N, D) # t1 [2, 512]
        cond = cond_in_module(cond)  # cond [2,1863,512]->[2,1863,512]

        x = x.transpose(1, 2)  # [2,1863,80]
        prompt_x = prompt_x.transpose(1, 2)  # [2,1863,80]

        x_in = torch.cat([x, prompt_x, cond], dim=-1)  # 80+80+512=672 [2, 1863, 672]

        if self.transformer_style_condition and not self.style_as_token:  # True and True
            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)  # [2, 1863, 864]

        if class_dropout:  # False
            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0  # 80维后全置为0

        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D) [2, 1863, 512]

        if self.style_as_token:  # False
            style = self.style_in(style)
            style = torch.zeros_like(style) if class_dropout else style
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)

        if self.time_as_token:  # False
            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)

        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)  # torch.Size([1, 1, 1863]) True
        input_pos = self.input_pos[:x_in.size(1)]  # (T,) range(0,1863)
        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None  # torch.Size([1, 1, 1863, 1863])
        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)  # [2, 1863, 512]
        x_res = x_res[:, 1:] if self.time_as_token else x_res
        x_res = x_res[:, 1:] if self.style_as_token else x_res

        if self.long_skip_connection:  # True
            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
        if self.final_layer_type == 'wavenet':
            x = self.conv1(x_res)
            x = x.transpose(1, 2)
            t2 = self.t_embedder2(t)
            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
                x_res)  # long residual connection
            x = self.final_layer(x, t1).transpose(1, 2)
            x = self.conv2(x)
        else:
            x = self.final_mlp(x_res)
            x = x.transpose(1, 2)
        # x [2,80,1863]
        return x
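A hedged, standalone sketch of the TimestepEmbedder defined above; the hidden size is an illustrative value, not one read from this repository's config, and the import path is assumed:

import torch
# from indextts.s2mel.modules.diffusion_transformer import TimestepEmbedder  # assumed path

emb = TimestepEmbedder(hidden_size=512)  # 256-dim sinusoidal features -> 512-dim MLP output
t = torch.rand(4)                        # one (possibly fractional) timestep per batch element
print(emb(t).shape)                      # torch.Size([4, 512])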
indextts/s2mel/modules/.ipynb_checkpoints/flow_matching-checkpoint.py
ADDED
@@ -0,0 +1,171 @@
from abc import ABC

import torch
import torch.nn.functional as F

from modules.diffusion_transformer import DiT
from modules.commons import sequence_mask

from tqdm import tqdm

class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        args,
    ):
        super().__init__()
        self.sigma_min = 1e-6

        self.estimator = None

        self.in_channels = args.DiT.in_channels

        self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()

        if hasattr(args.DiT, 'zero_prompt_speech_token'):
            self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
        else:
            self.zero_prompt_speech_token = False

    @torch.inference_mode()
    def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
        """Forward diffusion

        Args:
            mu (torch.Tensor): semantic info of reference audio and altered audio
                shape: (batch_size, mel_timesteps(795+1069), 512)
            x_lens (torch.Tensor): mel frames output
                shape: (batch_size, mel_timesteps)
            prompt (torch.Tensor): reference mel
                shape: (batch_size, 80, 795)
            style (torch.Tensor): reference global style
                shape: (batch_size, 192)
            f0: None
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, 80, mel_timesteps)
        """
        B, T = mu.size(0), mu.size(1)
        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
        return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)

    def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): semantic info of reference audio and altered audio
                shape: (batch_size, mel_timesteps(795+1069), 512)
            x_lens (torch.Tensor): mel frames output
                shape: (batch_size, mel_timesteps)
            prompt (torch.Tensor): reference mel
                shape: (batch_size, 80, 795)
            style (torch.Tensor): reference global style
                shape: (batch_size, 192)
        """
        t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []
        # apply prompt
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0
        if self.zero_prompt_speech_token:
            mu[..., :prompt_len] = 0
        for step in tqdm(range(1, len(t_span))):
            dt = t_span[step] - t_span[step - 1]
            if inference_cfg_rate > 0:
                # Stack original and CFG (null) inputs for batched processing
                stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
                stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
                stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
                stacked_x = torch.cat([x, x], dim=0)
                stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)

                # Perform a single forward pass for both original and CFG inputs
                stacked_dphi_dt = self.estimator(
                    stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
                )

                # Split the output back into the original and CFG components
                dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)

                # Apply CFG formula
                dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
            else:
                dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
            x[:, :, :prompt_len] = 0

        return sol[-1]

    def forward(self, x1, x_lens, prompt_lens, mu, style):
        """Computes diffusion loss

        Args:
            mu (torch.Tensor): semantic info of reference audio and altered audio
                shape: (batch_size, mel_timesteps(795+1069), 512)
            x1: mel
            x_lens (torch.Tensor): mel frames output
                shape: (batch_size, mel_timesteps)
            prompt (torch.Tensor): reference mel
                shape: (batch_size, 80, 795)
            style (torch.Tensor): reference global style
                shape: (batch_size, 192)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = x1.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        prompt = torch.zeros_like(x1)
        for bib in range(b):
            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
            # range covered by prompt are set to 0
            y[bib, :, :prompt_lens[bib]] = 0
            if self.zero_prompt_speech_token:
                mu[bib, :, :prompt_lens[bib]] = 0

        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
        loss = 0
        for bib in range(b):
            loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
        loss /= b

        return loss, estimator_out + (1 - self.sigma_min) * z


class CFM(BASECFM):
    def __init__(self, args):
        super().__init__(
            args
        )
        if args.dit_type == "DiT":
            self.estimator = DiT(args)
        else:
            raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
indextts/s2mel/modules/.ipynb_checkpoints/length_regulator-checkpoint.py
ADDED
@@ -0,0 +1,141 @@
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from modules.commons import sequence_mask
import numpy as np
from dac.nn.quantize import VectorQuantize

# f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)

def f0_to_coarse(f0, f0_bin):
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
    f0_coarse = torch.round(f0_mel).long()
    f0_coarse = f0_coarse * (f0_coarse > 0)
    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
    f0_coarse = f0_coarse * (f0_coarse < f0_bin)
    f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
    return f0_coarse

class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            is_discrete: bool = False,
            in_channels: int = None,  # only applies to continuous input
            vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
            codebook_size: int = 1024,  # for discrete only
            out_channels: int = None,
            groups: int = 1,
            n_codebooks: int = 1,  # number of codebooks
            quantizer_dropout: float = 0.0,  # dropout for quantizer
            f0_condition: bool = False,
            n_f0_bins: int = 512,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            self.interpolate = True
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        else:
            self.interpolate = False
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)
        self.embedding = nn.Embedding(codebook_size, channels)
        self.is_discrete = is_discrete

        self.mask_token = nn.Parameter(torch.zeros(1, channels))

        self.n_codebooks = n_codebooks
        if n_codebooks > 1:
            self.extra_codebooks = nn.ModuleList([
                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
            ])
            self.extra_codebook_mask_tokens = nn.ParameterList([
                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
            ])
        self.quantizer_dropout = quantizer_dropout

        if f0_condition:
            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
            self.f0_condition = f0_condition
            self.n_f0_bins = n_f0_bins
            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
        else:
            self.f0_condition = False

        if not is_discrete:
            self.content_in_proj = nn.Linear(in_channels, channels)
            if vector_quantize:
                self.vq = VectorQuantize(channels, codebook_size, 8)

    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
        # apply token drop
        if self.training:
            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
            n_dropout = int(x.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)
            # decide whether to drop for each sample in batch
        else:
            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
        if self.is_discrete:
            if self.n_codebooks > 1:
                assert len(x.size()) == 3
                x_emb = self.embedding(x[:, 0])
                for i, emb in enumerate(self.extra_codebooks):
                    x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
                    # add mask token if not using this codebook
                    # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
                x = x_emb
            elif self.n_codebooks == 1:
                if len(x.size()) == 2:
                    x = self.embedding(x)
                else:
                    x = self.embedding(x[:, 0])
        else:
            x = self.content_in_proj(x)
        # x in (B, T, D)
        mask = sequence_mask(ylens).unsqueeze(-1)
        if self.interpolate:
            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
        else:
            x = x.transpose(1, 2).contiguous()
            mask = mask[:, :x.size(2), :]
            ylens = ylens.clamp(max=x.size(2)).long()
        if self.f0_condition:
            if f0 is None:
                x = x + self.f0_mask.unsqueeze(-1)
            else:
                # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
                f0_emb = self.f0_embedding(quantized_f0)
                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
                x = x + f0_emb
        out = self.model(x).transpose(1, 2).contiguous()
        if hasattr(self, 'vq'):
            out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
            out_q = out_q.transpose(1, 2)
            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
        olens = ylens
        return out * mask, olens, None, None, None
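A hedged example of the f0_to_coarse quantizer above; the pitch values are chosen for illustration and the import path is an assumption:

import torch
# from indextts.s2mel.modules.length_regulator import f0_to_coarse  # assumed path

f0 = torch.tensor([0.0, 100.0, 440.0])  # Hz; 0 marks an unvoiced frame
bins = f0_to_coarse(f0, 512)
print(bins)  # unvoiced frames map to bin 1, voiced frames land in 2..511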
webui.py
CHANGED
@@ -38,7 +38,9 @@ from modelscope.hub import api
 
 i18n = I18nAuto(language="Auto")
 MODE = 'local'
-tts = IndexTTS2(model_dir=cmd_args.model_dir,
+tts = IndexTTS2(model_dir=cmd_args.model_dir,
+                cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
+                is_fp16=False,use_cuda_kernel=False)
 
 # 支持的语言列表
 LANGUAGES = {