#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright (c) 2021 Kazuhiro KOBAYASHI <root.4mac@gmail.com>
#
# Distributed under terms of the MIT license.
"""
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layer import ConvLayers, DFTLayer
from .model import NeuralHomomorphicVocoder
from .module import CCepLTVFilter, SinusoidsGenerator
class IncrementalCacheConvClass(nn.Module):
def __init__(self):
super().__init__()
        # keep hook handles so previously registered hooks can be removed
self.handles = []
def _forward_without_cache(self, x):
raise NotImplementedError("Please implement _forward_without_cache")
def forward(self, caches, *inputs):
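        # the pre-hooks registered in reset_caches() consume `caches` in call
        # order and collect the updated caches into `self.new_caches`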
self.caches = caches
self.new_caches = []
self.cache_num = 0
x = self._forward(*inputs)
return x, self.new_caches
def reset_caches(self, *args, hop_size=128, batch_size=1):
self.caches = []
self.receptive_sizes = []
self._initialize_caches(batch_size=batch_size, hop_size=hop_size)
        # set ordering hooks
self._set_pre_hooks(cache_ordering=True)
        # calculate the order in which the cached convolutions are called
_ = self._forward_without_cache(*args)
        # remove the hook handles used for ordering
        for h in self.handles:
            h.remove()
        self.handles = []
        # set cache-concatenation hooks
        self._set_pre_hooks(cache_ordering=False)
        # zero-initialize the caches
        self.caches = [torch.zeros_like(c) for c in self.caches]
# remove conv padding
self._remove_padding()
return self.caches
def _initialize_caches(self, batch_size=1, hop_size=128):
self.caches_dict = {}
self.receptive_sizes_dict = {}
for k, m in self.named_modules():
if isinstance(m, nn.Conv1d):
if m.kernel_size[0] > 1:
receptive_size = self._get_receptive_size_1d(m)
                    # NOTE(k2kobayashi): postfilter_fn must accept
                    # inputs of hop_size length
if "postfilter_fn" in k:
receptive_size += hop_size - 1
self.caches_dict[id(m)] = torch.randn(
(batch_size, m.in_channels, receptive_size)
)
self.receptive_sizes_dict[id(m)] = receptive_size
def _set_pre_hooks(self, cache_ordering=True):
if cache_ordering:
func = self._cache_ordering
else:
func = self._concat_cache
for k, m in self.named_modules():
if isinstance(m, nn.Conv1d):
if m.kernel_size[0] > 1:
self.handles.append(m.register_forward_pre_hook(func))
def _concat_cache(self, module, inputs):
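        # splice the stored cache in front of the new input, keep only the most
        # recent `receptive_size` samples, and store the spliced tensor as the
        # next cache for this convolution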
def __concat_cache(inputs, cache, receptive_size):
            inputs = torch.cat([cache, inputs[0]], dim=-1)
inputs = inputs[..., -receptive_size:]
return inputs
cache = self.caches[self.cache_num]
receptive_size = self.receptive_sizes[self.cache_num]
inputs = __concat_cache(inputs, cache, receptive_size)
self.new_caches += [inputs]
self.cache_num += 1
return inputs
def _cache_ordering(self, module, inputs):
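        # record this module's cache and receptive size in actual call order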
self.caches.append(self.caches_dict[id(module)])
self.receptive_sizes.append(self.receptive_sizes_dict[id(module)])
def _remove_padding(self):
def __remove_padding(m):
if isinstance(m, torch.nn.Conv1d):
m.padding = (0,)
if isinstance(m, torch.nn.Conv2d):
m.padding = (0, 0)
self.apply(__remove_padding)
@staticmethod
def _get_receptive_size_1d(m):
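        # receptive field of a single dilated Conv1d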
return (m.kernel_size[0] - 1) * m.dilation[0] + 1
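# Illustrative cache protocol for subclasses (a sketch, not part of the original
# API; `model`, `dummy_inputs` and `stream` are hypothetical names). This uses the
# base-class forward(caches, *inputs) signature; the concrete subclasses below
# reorder the arguments:
#
#   conv_caches = model.reset_caches(*dummy_inputs, hop_size=128, batch_size=1)
#   for chunk in stream:
#       out, conv_caches = model(conv_caches, chunk)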
class IncrementalNeuralHomomorphicVocoder(
NeuralHomomorphicVocoder, IncrementalCacheConvClass
):
fs = 24000
fft_size = 1024
hop_size = 256
in_channels = 80
conv_channels = 256
ccep_size = 222
out_channels = 1
kernel_size = 3
dilation_size = 1
group_size = 8
fmin = 80
fmax = 7600
roll_size = 24
n_ltv_layers = 3
n_postfilter_layers = 4
n_ltv_postfilter_layers = 1
use_causal = False
use_reference_mag = False
use_tanh = False
use_uvmask = False
use_weight_norm = True
conv_type = "original"
postfilter_type = "ddsconv"
ltv_postfilter_type = "conv"
ltv_postfilter_kernel_size = 128
scaler_file = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
        assert kwargs.get("use_causal"), "use_causal=True is required for incremental inference"
self.impulse_generator = IncrementalSinusoidsGenerator(
hop_size=self.hop_size, fs=self.fs, use_uvmask=self.use_uvmask
)
self.ltv_harmonic = IncrementalCCepLTVFilter(
**self.ltv_params, feat2linear_fn=self.feat2linear_fn
)
self.ltv_noise = IncrementalCCepLTVFilter(**self.ltv_params)
self.window_size = self.ltv_harmonic.window_size
    def _forward_without_cache(self, *inputs):
        # run the plain (non-incremental) forward once so the pre-hooks can
        # record the order in which the cached convolutions are called
        return super()._forward(*inputs)
def forward(self, z, x, f0, uv, ltv_caches, conv_caches):
self.caches = conv_caches
self.new_caches = []
self.cache_num = 0
y, new_ltv_caches = self._incremental_forward(z, x, f0, uv, ltv_caches)
return y, new_ltv_caches, self.new_caches
def _incremental_forward(self, z, x, cf0, uv, ltv_caches):
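        # ltv_caches layout: [0] impulse-generator phase, [1] harmonic LTV input,
        # [2] previous harmonic LTV frame, [3] noise LTV input, [4] previous noise LTV frame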
if self.feat_scaler_fn is not None:
x = self.feat_scaler_fn(x)
# impulse
impulse, impulse_cache = self.impulse_generator.incremental_forward(
cf0, uv, ltv_caches[0]
)
# ltv for harmonic
harmonic = self._concat_ltv_input_cache(ltv_caches[1], impulse)
ltv_harm = self.ltv_harmonic.incremental_forward(x, harmonic)
sig_harm = ltv_caches[2][..., -self.hop_size :] + ltv_harm[..., : self.hop_size]
if self.ltv_harmonic.ltv_postfilter_fn is not None:
sig_harm = self.ltv_harmonic.ltv_postfilter_fn(
sig_harm.transpose(1, 2)
).transpose(1, 2)
# ltv for noise
noise = self._concat_ltv_input_cache(ltv_caches[3], z)
ltv_noise = self.ltv_noise.incremental_forward(x, noise)
sig_noise = (
ltv_caches[4][..., -self.hop_size :] + ltv_noise[..., : self.hop_size]
)
if self.ltv_noise.ltv_postfilter_fn is not None:
sig_noise = self.ltv_noise.ltv_postfilter_fn(
sig_noise.transpose(1, 2)
).transpose(1, 2)
# superimpose
y = sig_harm + sig_noise
if self.postfilter_fn is not None:
y = self.postfilter_fn(y.transpose(1, 2)).transpose(1, 2)
y = torch.tanh(y) if self.use_tanh else torch.clamp(y, -1, 1)
new_ltv_caches = [impulse_cache, harmonic, ltv_harm, noise, ltv_noise]
return y.reshape(1, self.out_channels, -1), new_ltv_caches
def reset_ltv_caches(self):
ltv_caches = []
# impulse generator
ltv_caches += [torch.zeros(1, 1, 1)]
# ltv harm
ltv_caches += [torch.zeros(1, 1, self.window_size)]
ltv_caches += [torch.zeros(1, 1, self.window_size)]
# ltv noise
ltv_caches += [torch.zeros(1, 1, self.window_size)]
ltv_caches += [torch.zeros(1, 1, self.window_size)]
return ltv_caches
def _concat_ltv_input_cache(self, cache, z):
        z = torch.cat([cache, z], dim=-1)
z = z[..., self.hop_size :]
return z
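# Illustrative streaming loop for the incremental vocoder (a sketch; the warm-up
# tensors z0/x0/f00/uv0 and the per-hop generator `chunk_stream` are assumptions,
# and constructor kwargs other than use_causal are omitted):
#
#   model = IncrementalNeuralHomomorphicVocoder(use_causal=True)
#   ltv_caches = model.reset_ltv_caches()
#   conv_caches = model.reset_caches(z0, x0, f00, uv0, hop_size=model.hop_size)
#   for z, x, f0, uv in chunk_stream:
#       y, ltv_caches, conv_caches = model(z, x, f0, uv, ltv_caches, conv_caches)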
class IncrementalSinusoidsGenerator(SinusoidsGenerator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def incremental_forward(self, cf0, uv, cache):
f0, uv = self.upsample(cf0.transpose(1, 2)), self.upsample(uv.transpose(1, 2))
harmonic, new_cache = self.incremental_generate_sinusoids(f0, uv, cache)
harmonic = self.harmonic_amp * harmonic.reshape(cf0.size(0), 1, -1)
return harmonic, new_cache
def incremental_generate_sinusoids(self, f0, uv, cache):
mask = self.anti_aliacing_mask(f0 * self.harmonics)
# f0[..., 0] = f0[..., 0] + cache
        f0 = torch.cat([cache, f0], dim=-1)
cumsum = torch.cumsum(f0, dim=-1)[..., 1:]
rads = cumsum * 2.0 * math.pi / self.fs * self.harmonics
harmonic = torch.sum(torch.cos(rads) * mask, dim=1, keepdim=True)
if self.use_uvmask:
harmonic = uv * harmonic
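        # carry the cumulative phase (wrapped modulo fs) so the sinusoids stay
        # continuous across consecutive chunks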
new_cache = cumsum[..., -1:] % self.fs
return harmonic, new_cache
class IncrementalConvLayers(ConvLayers, IncrementalCacheConvClass):
in_channels = 80
conv_channels = 256
out_channels = 222
kernel_size = 3
dilation_size = 1
group_size = 8
n_conv_layers = 3
use_causal = False
conv_type = "original"
def __init__(self, **kwargs):
for k, v in kwargs.items():
            if k not in self.__class__.__dict__:
                raise ValueError(f"{k} is not a valid argument of {self.__class__.__name__}.")
setattr(self, k, v)
        assert kwargs.get("use_causal"), "use_causal=True is required for incremental inference"
super().__init__(**kwargs)
    def _forward_without_cache(self, *inputs):
        # run the plain (non-incremental) forward once so the pre-hooks can
        # record the order in which the cached convolutions are called
        return super().forward(*inputs)
def forward(self, x, conv_caches):
self.caches = conv_caches
self.new_caches = []
self.cache_num = 0
x = self.conv_layers(x)
return x, self.new_caches
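# Illustrative streaming use of IncrementalConvLayers (a sketch; `x0` and `chunks`
# are assumed dummy/streamed feature tensors, other constructor kwargs omitted):
#
#   layers = IncrementalConvLayers(use_causal=True)
#   conv_caches = layers.reset_caches(x0, hop_size=128)
#   for x in chunks:
#       y, conv_caches = layers(x, conv_caches)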
class IncrementalCCepLTVFilter(CCepLTVFilter):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.conv_dft = DFTLayer(n_fft=self.fft_size)
self.conv_idft = DFTLayer(n_fft=self.fft_size + 1)
self.padding = (self.fft_size - self.ccep_size) // 2
def incremental_forward(self, x, z):
"""Input tensor size
x: (1, 1, input_size)
z: (1, 1, fft_size + hop_size)
"""
        # infer the complex cepstrum from the input features
ccep = self.conv(x) / self.quef_norm
log_mag = None if self.feat2linear_fn is None else self.feat2linear_fn(x)
y = self._dft_ccep2impulse(ccep, ref=log_mag)
# convolve to a frame
z = F.pad(z, (self.fft_size // 2, self.fft_size // 2))
z = F.conv1d(z, y)
return z * self.win
def _dft_ccep2impulse(self, ccep, ref=None):
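        # pad the complex cepstrum to fft_size, take its DFT, interpret the real
        # part as scaled log10 magnitude and the imaginary part as phase, then
        # inverse-DFT the resulting spectrum to obtain the impulse response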
ccep = F.pad(ccep, (self.padding, self.padding))
real, imag = self.conv_dft(ccep)
if ref is not None:
real = self._apply_ref_mag(real, ref)
mag, phase = torch.pow(10, real / 10), imag
real, imag = mag * torch.cos(phase), mag * torch.sin(phase)
real, _ = self.conv_idft(F.pad(real, (0, 1)), F.pad(imag, (0, 1)), inverse=True)
return real