kemuriririn committed
Commit 04be12f · 1 Parent(s): ff769a6

disable deepspeed and cuda kernel

indextts/infer_v2.py CHANGED
@@ -35,7 +35,7 @@ import torch.nn.functional as F
class IndexTTS2:
    def __init__(
            self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, device=None,
-           use_cuda_kernel=None,
+           use_cuda_kernel=None, use_deepspeed=False,
    ):
        """
        Args:
@@ -83,14 +83,13 @@ class IndexTTS2:
            try:
                import deepspeed

-               use_deepspeed = True
            except (ImportError, OSError, CalledProcessError) as e:
                use_deepspeed = False
                print(f">> DeepSpeed加载失败,回退到标准推理: {e}")

            self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
        else:
-           self.gpt.post_init_gpt2_config(use_deepspeed=True, kv_cache=True, half=False)
+           self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=False)

        if self.use_cuda_kernel:
            # preload the CUDA kernel for BigVGAN
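For context, a minimal usage sketch of the new signature; the argument values below mirror webui.py and are illustrative, not part of the commit.

from indextts.infer_v2 import IndexTTS2

# Both acceleration paths off, matching the new defaults introduced by this commit.
tts = IndexTTS2(
    cfg_path="checkpoints/config.yaml",
    model_dir="checkpoints",
    is_fp16=False,
    use_cuda_kernel=False,   # skip the BigVGAN CUDA kernel preload
    use_deepspeed=False,     # use standard PyTorch inference instead of DeepSpeed
)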
indextts/s2mel/modules/.ipynb_checkpoints/audio-checkpoint.py ADDED
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    # if torch.min(y) < -1.0:
    #     print("min value is ", torch.min(y))
    # if torch.max(y) > 1.0:
    #     print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
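A quick usage sketch of mel_spectrogram; the STFT/mel parameters here are assumed placeholder values, not taken from this repository's config.

import torch

y = torch.randn(1, 22050).clamp(-1, 1)   # one second of fake audio, shape (batch, samples)
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)   # (1, 80, n_frames) of log-compressed mel magnitudes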
indextts/s2mel/modules/.ipynb_checkpoints/commons-checkpoint.py ADDED
@@ -0,0 +1,610 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json
import argparse
from torch.nn.parallel import DistributedDataParallel as DDP

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments_audio(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
        dtype=torch.long
    )
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def avg_with_mask(x, mask):
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask.unsqueeze(1)

    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    return (x * mask).sum() / mask.sum()


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)
    """
    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
    return x


def load_F0_models(path):
    # load F0 model
    from .JDC.model import JDCNet

    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(path, map_location="cpu")["net"]
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model


def modify_w2v_forward(self, output_layer=15):
    """
    change forward method of w2v encoder to get its intermediate layer output
    :param self:
    :param layer:
    :return:
    """
    from transformers.modeling_outputs import BaseModelOutput

    def forward(
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        conv_attention_mask = attention_mask
        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = hidden_states.masked_fill(
                ~attention_mask.bool().unsqueeze(-1), 0.0
            )

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype
            )
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )

        hidden_states = self.dropout(hidden_states)

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)
        else:
            relative_position_embeddings = None

        deepspeed_zero3_is_enabled = False

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = (
                True
                if self.training and (dropout_probability < self.config.layerdrop)
                else False
            )
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        relative_position_embeddings,
                        output_attentions,
                        conv_attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        relative_position_embeddings=relative_position_embeddings,
                        output_attentions=output_attentions,
                        conv_attention_mask=conv_attention_mask,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

            if i == output_layer - 1:
                break

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward


MATPLOTLIB_FLAG = False


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        import logging

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def normalize_f0(f0_sequence):
    # Remove unvoiced frames (replace with -1)
    voiced_indices = np.where(f0_sequence > 0)[0]
    f0_voiced = f0_sequence[voiced_indices]

    # Convert to log scale
    log_f0 = np.log2(f0_voiced)

    # Calculate mean and standard deviation
    mean_f0 = np.mean(log_f0)
    std_f0 = np.std(log_f0)

    # Normalize the F0 sequence
    normalized_f0 = (log_f0 - mean_f0) / std_f0

    # Create the normalized F0 sequence with unvoiced frames
    normalized_sequence = np.zeros_like(f0_sequence)
    normalized_sequence[voiced_indices] = normalized_f0
    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames

    return normalized_sequence


class MyModel(nn.Module):
    def __init__(self, args):
        super(MyModel, self).__init__()
        from modules.flow_matching import CFM
        from modules.length_regulator import InterpolateRegulator

        length_regulator = InterpolateRegulator(
            channels=args.length_regulator.channels,
            sampling_ratios=args.length_regulator.sampling_ratios,
            is_discrete=args.length_regulator.is_discrete,
            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
            codebook_size=args.length_regulator.content_codebook_size,
            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
        )

        self.models = nn.ModuleDict({
            'cfm': CFM(args),
            'length_regulator': length_regulator
        })

    def forward(self, x, target_lengths, prompt_len, cond, y):
        x = self.models['cfm'](x, target_lengths, prompt_len, cond, y)
        return x

    def forward2(self, S_ori, target_lengths, F0_ori):
        x = self.models['length_regulator'](S_ori, ylens=target_lengths, f0=F0_ori)
        return x

def build_model(args, stage="DiT"):
    if stage == "DiT":
        from modules.flow_matching import CFM
        from modules.length_regulator import InterpolateRegulator

        length_regulator = InterpolateRegulator(
            channels=args.length_regulator.channels,
            sampling_ratios=args.length_regulator.sampling_ratios,
            is_discrete=args.length_regulator.is_discrete,
            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
            codebook_size=args.length_regulator.content_codebook_size,
            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
        )
        cfm = CFM(args)
        nets = Munch(
            cfm=cfm,
            length_regulator=length_regulator,
        )

    elif stage == 'codec':
        from dac.model.dac import Encoder
        from modules.quantize import (
            FAquantizer,
        )

        encoder = Encoder(
            d_model=args.DAC.encoder_dim,
            strides=args.DAC.encoder_rates,
            d_latent=1024,
            causal=args.causal,
            lstm=args.lstm,
        )

        quantizer = FAquantizer(
            in_dim=1024,
            n_p_codebooks=1,
            n_c_codebooks=args.n_c_codebooks,
            n_t_codebooks=2,
            n_r_codebooks=3,
            codebook_size=1024,
            codebook_dim=8,
            quantizer_dropout=0.5,
            causal=args.causal,
            separate_prosody_encoder=args.separate_prosody_encoder,
            timbre_norm=args.timbre_norm,
        )

        nets = Munch(
            encoder=encoder,
            quantizer=quantizer,
        )

    elif stage == "mel_vocos":
        from modules.vocos import Vocos
        decoder = Vocos(args)
        nets = Munch(
            decoder=decoder,
        )

    else:
        raise ValueError(f"Unknown stage: {stage}")

    return nets


def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
    load_ema=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    if load_ema and "ema" in state:
        print("Loading EMA")
        for key in model:
            i = 0
            for param_name in params[key]:
                if "input_pos" in param_name:
                    continue
                assert params[key][param_name].shape == state["ema"][key][0][i].shape
                params[key][param_name] = state["ema"][key][0][i].clone()
                i += 1
    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            model_state_dict = model[key].state_dict()
            # keep only the key/value pairs whose shapes match the model
            filtered_state_dict = {
                k: v
                for k, v in params[key].items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape
            }
            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
            if skipped_keys:
                print(
                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
                )
            print("%s loaded" % key)
            model[key].load_state_dict(filtered_state_dict, strict=False)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters

def load_checkpoint2(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
    load_ema=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    if load_ema and "ema" in state:
        print("Loading EMA")
        for key in model.models:
            i = 0
            for param_name in params[key]:
                if "input_pos" in param_name:
                    continue
                assert params[key][param_name].shape == state["ema"][key][0][i].shape
                params[key][param_name] = state["ema"][key][0][i].clone()
                i += 1
    for key in model.models:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            model_state_dict = model.models[key].state_dict()
            # keep only the key/value pairs whose shapes match the model
            filtered_state_dict = {
                k: v
                for k, v in params[key].items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape
            }
            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
            if skipped_keys:
                print(
                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
                )
            print("%s loaded" % key)
            model.models[key].load_state_dict(filtered_state_dict, strict=False)
    model.eval()
    # _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters

def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
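Two of these helpers are used throughout the s2mel modules; a small illustrative sketch with made-up values:

import torch

# sequence_mask: per-row boolean mask of valid positions
lengths = torch.tensor([3, 5])
print(sequence_mask(lengths))        # shape (2, 5); row 0 -> True, True, True, False, False

# recursive_munch: attribute-style access into a nested config dict (keys/values assumed)
cfg = recursive_munch({"DiT": {"hidden_dim": 512, "depth": 13}})
print(cfg.DiT.hidden_dim)            # 512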
indextts/s2mel/modules/.ipynb_checkpoints/diffusion_transformer-checkpoint.py ADDED
@@ -0,0 +1,258 @@
import torch
from torch import nn
import math

from modules.gpt_fast.model import ModelArgs, Transformer
# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
from modules.wavenet import WN
from modules.commons import sequence_mask

from torch.nn.utils import weight_norm

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = 10000
        self.scale = 1000

        half = frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        )
        self.register_buffer("freqs", freqs)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        args = self.scale * t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.frequency_embedding_size % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t)
        t_emb = self.mlp(t_freq)
        return t_emb


class StyleEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
        self.input_size = input_size
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        else:
            labels = self.style_in(labels)
        embeddings = labels
        return embeddings

class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class DiT(torch.nn.Module):
    def __init__(
        self,
        args
    ):
        super(DiT, self).__init__()
        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
        model_args = ModelArgs(
            block_size=16384,  # args.DiT.block_size,
            n_layer=args.DiT.depth,
            n_head=args.DiT.num_heads,
            dim=args.DiT.hidden_dim,
            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
            vocab_size=1024,
            uvit_skip_connection=self.uvit_skip_connection,
            time_as_token=self.time_as_token,
        )
        self.transformer = Transformer(model_args)
        self.in_channels = args.DiT.in_channels
        self.out_channels = args.DiT.in_channels
        self.num_heads = args.DiT.num_heads

        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))

        self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
        self.content_codebook_size = args.DiT.content_codebook_size  # for discrete content
        self.content_dim = args.DiT.content_dim  # for continuous content
        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True)  # continuous content

        self.is_causal = args.DiT.is_causal

        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)

        # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
        # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))

        input_pos = torch.arange(16384)
        self.register_buffer("input_pos", input_pos)

        self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
        if self.final_layer_type == 'wavenet':
            self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
            self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
            self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
                              kernel_size=args.wavenet.kernel_size,
                              dilation_rate=args.wavenet.dilation_rate,
                              n_layers=args.wavenet.num_layers,
                              gin_channels=args.wavenet.hidden_dim,
                              p_dropout=args.wavenet.p_dropout,
                              causal=False)
            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
            self.res_projection = nn.Linear(args.DiT.hidden_dim,
                                            args.wavenet.hidden_dim)  # residual connection from transformer output to final output
            self.wavenet_style_condition = args.wavenet.style_condition
            assert args.DiT.style_condition == args.wavenet.style_condition
        else:
            self.final_mlp = nn.Sequential(
                nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
                nn.SiLU(),
                nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
            )
        self.transformer_style_condition = args.DiT.style_condition

        self.class_dropout_prob = args.DiT.class_dropout_prob
        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)

        self.long_skip_connection = args.DiT.long_skip_connection
        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)

        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
                                             args.DiT.hidden_dim)
        if self.style_as_token:
            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)

    def setup_caches(self, max_batch_size, max_seq_length):
        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)

    def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
        """
        x (torch.Tensor): random noise
        prompt_x (torch.Tensor): reference mel + zero mel
            shape: (batch_size, 80, 795+1068)
        x_lens (torch.Tensor): mel frames output
            shape: (batch_size, mel_timesteps)
        t (torch.Tensor): random timestep
            shape: (batch_size,)
        style (torch.Tensor): reference global style
            shape: (batch_size, 192)
        cond (torch.Tensor): semantic info of reference audio and altered audio
            shape: (batch_size, mel_timesteps(795+1069), 512)
        """
        class_dropout = False
        if self.training and torch.rand(1) < self.class_dropout_prob:
            class_dropout = True
        if not self.training and mask_content:
            class_dropout = True
        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
        cond_in_module = self.cond_projection

        B, _, T = x.size()

        t1 = self.t_embedder(t)  # (N, D), e.g. [2, 512]
        cond = cond_in_module(cond)  # [2, 1863, 512] -> [2, 1863, 512]

        x = x.transpose(1, 2)  # [2, 1863, 80]
        prompt_x = prompt_x.transpose(1, 2)  # [2, 1863, 80]

        x_in = torch.cat([x, prompt_x, cond], dim=-1)  # 80+80+512=672 -> [2, 1863, 672]

        if self.transformer_style_condition and not self.style_as_token:
            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)  # [2, 1863, 864]

        if class_dropout:
            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0  # zero everything after the first 80 dims

        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D), e.g. [2, 1863, 512]

        if self.style_as_token:
            style = self.style_in(style)
            style = torch.zeros_like(style) if class_dropout else style
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)

        if self.time_as_token:
            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)

        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)  # e.g. [1, 1, 1863], all True
        input_pos = self.input_pos[:x_in.size(1)]  # (T,), range(0, 1863)
        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None  # e.g. [1, 1, 1863, 1863]
        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)  # [2, 1863, 512]
        x_res = x_res[:, 1:] if self.time_as_token else x_res
        x_res = x_res[:, 1:] if self.style_as_token else x_res

        if self.long_skip_connection:
            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
        if self.final_layer_type == 'wavenet':
            x = self.conv1(x_res)
            x = x.transpose(1, 2)
            t2 = self.t_embedder2(t)
            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
                x_res)  # long residual connection
            x = self.final_layer(x, t1).transpose(1, 2)
            x = self.conv2(x)
        else:
            x = self.final_mlp(x_res)
            x = x.transpose(1, 2)
        # x: [2, 80, 1863]
        return x
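A shape-only sketch of TimestepEmbedder; the hidden size of 512 is an assumed value for illustration.

import torch

emb = TimestepEmbedder(hidden_size=512)   # hidden size assumed, not taken from the config
t = torch.rand(2)                         # one diffusion time per batch item, in [0, 1)
print(emb(t).shape)                       # torch.Size([2, 512])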
indextts/s2mel/modules/.ipynb_checkpoints/flow_matching-checkpoint.py ADDED
@@ -0,0 +1,171 @@
from abc import ABC

import torch
import torch.nn.functional as F

from modules.diffusion_transformer import DiT
from modules.commons import sequence_mask

from tqdm import tqdm

class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        args,
    ):
        super().__init__()
        self.sigma_min = 1e-6

        self.estimator = None

        self.in_channels = args.DiT.in_channels

        self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()

        if hasattr(args.DiT, 'zero_prompt_speech_token'):
            self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
        else:
            self.zero_prompt_speech_token = False

    @torch.inference_mode()
    def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
        """Forward diffusion

        Args:
            mu (torch.Tensor): semantic info of reference audio and altered audio
                shape: (batch_size, mel_timesteps(795+1069), 512)
            x_lens (torch.Tensor): mel frames output
                shape: (batch_size, mel_timesteps)
            prompt (torch.Tensor): reference mel
                shape: (batch_size, 80, 795)
            style (torch.Tensor): reference global style
                shape: (batch_size, 192)
            f0: None
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, 80, mel_timesteps)
        """
        B, T = mu.size(0), mu.size(1)
        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
        return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)

    def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): semantic info of reference audio and altered audio
                shape: (batch_size, mel_timesteps(795+1069), 512)
            x_lens (torch.Tensor): mel frames output
                shape: (batch_size, mel_timesteps)
            prompt (torch.Tensor): reference mel
                shape: (batch_size, 80, 795)
            style (torch.Tensor): reference global style
                shape: (batch_size, 192)
        """
        t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []
        # apply prompt
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0
        if self.zero_prompt_speech_token:
            mu[..., :prompt_len] = 0
        for step in tqdm(range(1, len(t_span))):
            dt = t_span[step] - t_span[step - 1]
            if inference_cfg_rate > 0:
                # Stack original and CFG (null) inputs for batched processing
                stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
                stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
                stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
                stacked_x = torch.cat([x, x], dim=0)
                stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)

                # Perform a single forward pass for both original and CFG inputs
                stacked_dphi_dt = self.estimator(
                    stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
                )

                # Split the output back into the original and CFG components
                dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)

                # Apply CFG formula
                dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
            else:
                dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
            x[:, :, :prompt_len] = 0

        return sol[-1]

    def forward(self, x1, x_lens, prompt_lens, mu, style):
        """Computes diffusion loss

        Args:
            mu (torch.Tensor): semantic info of reference audio and altered audio
                shape: (batch_size, mel_timesteps(795+1069), 512)
            x1: mel
            x_lens (torch.Tensor): mel frames output
                shape: (batch_size, mel_timesteps)
            prompt (torch.Tensor): reference mel
                shape: (batch_size, 80, 795)
            style (torch.Tensor): reference global style
                shape: (batch_size, 192)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = x1.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        prompt = torch.zeros_like(x1)
        for bib in range(b):
            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
            # range covered by prompt are set to 0
            y[bib, :, :prompt_lens[bib]] = 0
            if self.zero_prompt_speech_token:
                mu[bib, :, :prompt_lens[bib]] = 0

        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
        loss = 0
        for bib in range(b):
            loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
        loss /= b

        return loss, estimator_out + (1 - self.sigma_min) * z


class CFM(BASECFM):
    def __init__(self, args):
        super().__init__(
            args
        )
        if args.dit_type == "DiT":
            self.estimator = DiT(args)
        else:
            raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
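The conditional flow in BASECFM.forward is a straight-line path between noise and the target mel; a toy check of the endpoints (made-up tensors, not model data):

import torch

sigma_min = 1e-6
x1 = torch.randn(1, 80, 10)    # pretend target mel
z = torch.randn_like(x1)       # noise sample

for t_val in (0.0, 1.0):
    t = torch.full((1, 1, 1), t_val)
    y = (1 - (1 - sigma_min) * t) * z + t * x1   # same interpolation as in forward()
    # at t=0 the flow equals the noise z; at t=1 it (almost exactly) equals the target x1
    print(t_val, torch.allclose(y, z if t_val == 0.0 else x1, atol=1e-4))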
indextts/s2mel/modules/.ipynb_checkpoints/length_regulator-checkpoint.py ADDED
@@ -0,0 +1,141 @@
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from modules.commons import sequence_mask
import numpy as np
from dac.nn.quantize import VectorQuantize

# f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)

def f0_to_coarse(f0, f0_bin):
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
    f0_coarse = torch.round(f0_mel).long()
    f0_coarse = f0_coarse * (f0_coarse > 0)
    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
    f0_coarse = f0_coarse * (f0_coarse < f0_bin)
    f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
    return f0_coarse

class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            is_discrete: bool = False,
            in_channels: int = None,  # only applies to continuous input
            vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
            codebook_size: int = 1024,  # for discrete only
            out_channels: int = None,
            groups: int = 1,
            n_codebooks: int = 1,  # number of codebooks
            quantizer_dropout: float = 0.0,  # dropout for quantizer
            f0_condition: bool = False,
            n_f0_bins: int = 512,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            self.interpolate = True
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        else:
            self.interpolate = False
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)
        self.embedding = nn.Embedding(codebook_size, channels)
        self.is_discrete = is_discrete

        self.mask_token = nn.Parameter(torch.zeros(1, channels))

        self.n_codebooks = n_codebooks
        if n_codebooks > 1:
            self.extra_codebooks = nn.ModuleList([
                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
            ])
            self.extra_codebook_mask_tokens = nn.ParameterList([
                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
            ])
        self.quantizer_dropout = quantizer_dropout

        if f0_condition:
            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
            self.f0_condition = f0_condition
            self.n_f0_bins = n_f0_bins
            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
        else:
            self.f0_condition = False

        if not is_discrete:
            self.content_in_proj = nn.Linear(in_channels, channels)
            if vector_quantize:
                self.vq = VectorQuantize(channels, codebook_size, 8)

    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
        # apply token drop
        if self.training:
            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
            n_dropout = int(x.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)
            # decide whether to drop for each sample in batch
        else:
            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
        if self.is_discrete:
            if self.n_codebooks > 1:
                assert len(x.size()) == 3
                x_emb = self.embedding(x[:, 0])
                for i, emb in enumerate(self.extra_codebooks):
                    x_emb = x_emb + (n_quantizers > i + 1)[..., None, None] * emb(x[:, i + 1])
                    # add mask token if not using this codebook
                    # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
                x = x_emb
            elif self.n_codebooks == 1:
                if len(x.size()) == 2:
                    x = self.embedding(x)
                else:
                    x = self.embedding(x[:, 0])
        else:
            x = self.content_in_proj(x)
        # x in (B, T, D)
        mask = sequence_mask(ylens).unsqueeze(-1)
        if self.interpolate:
            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
        else:
            x = x.transpose(1, 2).contiguous()
            mask = mask[:, :x.size(2), :]
            ylens = ylens.clamp(max=x.size(2)).long()
        if self.f0_condition:
            if f0 is None:
                x = x + self.f0_mask.unsqueeze(-1)
            else:
                # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
                f0_emb = self.f0_embedding(quantized_f0)
                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
                x = x + f0_emb
        out = self.model(x).transpose(1, 2).contiguous()
        if hasattr(self, 'vq'):
            out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
            out_q = out_q.transpose(1, 2)
            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
        olens = ylens
        return out * mask, olens, None, None, None
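A small sketch of f0_to_coarse, which maps F0 values in Hz onto mel-spaced bins (512 matches the regulator's default n_f0_bins; the pitch values are made up):

import torch

f0 = torch.tensor([[0.0, 110.0, 220.0, 440.0]])   # 0 Hz marks an unvoiced frame
print(f0_to_coarse(f0, 512))   # unvoiced frames land in bin 1; higher pitch maps to a higher bin index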
webui.py CHANGED
@@ -38,7 +38,9 @@ from modelscope.hub import api

i18n = I18nAuto(language="Auto")
MODE = 'local'
-tts = IndexTTS2(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"), is_fp16=cmd_args.is_fp16)
+tts = IndexTTS2(model_dir=cmd_args.model_dir,
+                cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
+                is_fp16=False, use_cuda_kernel=False)

# supported language list
LANGUAGES = {