#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright (c) 2021 Kazuhiro KOBAYASHI
#
# Distributed under terms of the MIT license.

"""Complex-cepstrum linear time-varying filter and harmonic source generator.

This module defines two ``nn.Module`` classes:

* ``CCepLTVFilter`` predicts a frame-wise complex cepstrum from input
  features, converts it to a frame-wise impulse response, filters an
  excitation signal frame by frame and overlap-adds the result.
* ``SinusoidsGenerator`` builds a sum-of-cosines signal from a frame-level
  F0 contour.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft

from .layer import Conv1d, ConvLayers


class CCepLTVFilter(nn.Module):
    """Linear time-varying filter parameterised by a predicted complex cepstrum.

    Args:
        in_channels: Feature dimension forwarded to ``ConvLayers``.
        conv_channels: Hidden channel count forwarded to ``ConvLayers``.
        ccep_size: Number of cepstral coefficients predicted per frame.
            Must be even because the quefrency normaliser is built from two
            mirrored halves of ``ccep_size // 2`` entries.
        kernel_size: Kernel size forwarded to ``ConvLayers``.
        dilation_size: Dilation forwarded to ``ConvLayers``.
        group_size: Group size forwarded to ``ConvLayers``.
        fft_size: FFT length used to turn the cepstrum into a spectrum.
        hop_size: Frame shift in samples; the analysis window is
            ``2 * hop_size`` samples long.
        n_ltv_layers: Number of layers of the cepstrum predictor.
        n_ltv_postfilter_layers: Number of layers of the ``"ddsconv"``
            post-filter.
        use_causal: Forwarded to the convolution layers.
        conv_type: Forwarded to ``ConvLayers`` for the cepstrum predictor.
        feat2linear_fn: Optional callable applied to ``x`` in ``forward``.
            Its output is added to the real part of the cepstrum's FFT, so
            it must provide ``fft_size // 2 + 1`` bins on its last axis.
        ltv_postfilter_type: ``"conv"``, ``"ddsconv"`` or ``None``.
        ltv_postfilter_kernel_size: Kernel size of the ``"conv"`` post-filter.

    Raises:
        ValueError: If ``ccep_size`` is odd or ``ltv_postfilter_type`` is not
            one of the supported values.
    """

    def __init__(
        self,
        in_channels,
        conv_channels=256,
        ccep_size=222,
        kernel_size=3,
        dilation_size=1,
        group_size=8,
        fft_size=1024,
        hop_size=256,
        n_ltv_layers=3,
        n_ltv_postfilter_layers=1,
        use_causal=False,
        conv_type="original",
        feat2linear_fn=None,
        ltv_postfilter_type="conv",
        ltv_postfilter_kernel_size=128,
    ):
        super().__init__()
        # quef_norm below has 2 * (ccep_size // 2) entries; with an odd
        # ccep_size it would be one element short of the predicted cepstrum
        # and forward() would only fail later with a broadcasting error.
        if ccep_size % 2 != 0:
            raise ValueError(f"ccep_size must be even, got {ccep_size}")

        self.fft_size = fft_size
        self.hop_size = hop_size
        self.window_size = hop_size * 2
        self.ccep_size = ccep_size
        self.use_causal = use_causal
        self.feat2linear_fn = feat2linear_fn
        self.ltv_postfilter_type = ltv_postfilter_type
        self.ltv_postfilter_kernel_size = ltv_postfilter_kernel_size
        self.n_ltv_postfilter_layers = n_ltv_postfilter_layers

        win_norm = self.window_size // (hop_size * 2)  # only for hanning window
        # periodic must be True to become OLA 1
        win = torch.hann_window(self.window_size, periodic=True) / win_norm

        # Frame-wise complex cepstrum predictor.
        self.conv = ConvLayers(
            in_channels=in_channels,
            conv_channels=conv_channels,
            out_channels=ccep_size,
            kernel_size=kernel_size,
            dilation_size=dilation_size,
            group_size=group_size,
            n_conv_layers=n_ltv_layers,
            use_causal=use_causal,
            conv_type=conv_type,
        )
        self.ltv_postfilter_fn = self._get_ltv_postfilter_fn()

        # Quefrency normaliser: [ccep_size/2, ..., 2, 1, 1, 2, ..., ccep_size/2].
        idx = torch.arange(1, ccep_size // 2 + 1).float()
        quef_norm = torch.cat([torch.flip(idx, dims=[-1]), idx], dim=-1)
        # Zero padding applied on each side of the cepstrum before the FFT.
        self.padding = (self.fft_size - self.ccep_size) // 2
        self.register_buffer("quef_norm", quef_norm)
        self.register_buffer("win", win)

    def forward(self, x, z):
        """Filter the excitation ``z`` with filters predicted from ``x``.

        Args:
            x: Frame-level features, (B, T, D).
            z: Excitation signal, (B, 1, T * hop_size).

        Returns:
            Filtered signal reshaped to (B, 1, -1) by the overlap-add step
            and, when configured, passed through the post-filter.
        """
        # inference complex cepstrum
        ccep = self.conv(x) / self.quef_norm

        # apply LTV filter and overlap
        log_mag = None if self.feat2linear_fn is None else self.feat2linear_fn(x)
        y = self._ccep2impulse(ccep, ref=log_mag)
        z = self._conv_impulse(z, y)
        z = self._ola(z)
        if self.ltv_postfilter_fn is not None:
            z = self.ltv_postfilter_fn(z.transpose(1, 2)).transpose(1, 2)
        return z

    def _apply_ref_mag(self, real, ref):
        """Add a one-sided reference log-magnitude to a two-sided spectrum.

        ``real`` is modified in place and returned.

        Args:
            real: Tensor with ``fft_size`` bins on its last axis.
            ref: Tensor with ``fft_size // 2 + 1`` bins on its last axis
                (DC through Nyquist).

        Returns:
            ``real`` with ``ref`` added to bins ``0 .. fft_size // 2`` and
            its mirror image added to bins ``fft_size // 2 + 1 .. fft_size - 1``.
        """
        # TODO(k2kobayashi): it requires to consider following line.
        # this mask eliminates very small amplitude values (-100).
        # ref = ref * (ref > -100)
        half = self.fft_size // 2
        real[..., : half + 1] += ref
        # Mirror only the interior bins (1 .. half - 1). The DC and Nyquist
        # bins have no mirror partner; starting the second slice at `half`
        # and flipping ref[..., 1:] would add the Nyquist value twice.
        real[..., half + 1 :] += torch.flip(ref[..., 1:-1], dims=[-1])
        return real

    def _ccep2impulse(self, ccep, ref=None):
        """Convert a frame-wise cepstrum into a frame-wise impulse response.

        Args:
            ccep: Cepstrum with ``ccep_size`` entries on its last axis.
            ref: Optional reference log-magnitude added to the real part of
                the spectrum (see ``_apply_ref_mag``).

        Returns:
            Real tensor with ``fft_size + 1`` samples on its last axis.
        """
        ccep = F.pad(ccep, (self.padding, self.padding))
        y = torch.fft.fft(ccep, n=self.fft_size, dim=-1)
        # NOTE(k2kobayashi): we assume intermediate log amplitude as 10log10|mag|
        if ref is not None:
            y.real = self._apply_ref_mag(y.real, ref)
        # logarithmic to linear
        mag, phase = torch.pow(10, y.real / 10), y.imag
        real, imag = mag * torch.cos(phase), mag * torch.sin(phase)
        # n = fft_size + 1 makes the returned impulse response odd-length,
        # matching the fft_size + 1 taps consumed by _conv_impulse.
        y = torch.fft.ifft(torch.complex(real, imag), n=self.fft_size + 1, dim=-1)
        return y.real

    def _conv_impulse(self, z, y):
        """Apply each frame's impulse response to the matching signal frame.

        Args:
            z: Excitation, (B, 1, T * hop_size).
            y: Impulse responses, (B, T, fft_size + 1).

        Returns:
            Filtered frames, (B, T, window_size).
        """
        # (B, T * hop_size + hop_size)
        # z = F.pad(z, (self.hop_size // 2, self.hop_size // 2)).squeeze(1)
        z = F.pad(z, (self.hop_size, 0)).squeeze(1)
        z = z.unfold(-1, self.window_size, step=self.hop_size)  # (B, T, window_size)

        z = F.pad(z, (self.fft_size // 2, self.fft_size // 2))
        z = z.unfold(-1, self.fft_size + 1, step=1)  # (B, T, window_size, fft_size + 1)

        # y: (B, T, fft_size + 1) -> (B, T, fft_size + 1, 1)
        # z: (B, T, window_size, fft_size + 1)
        # output: (B, T, window_size)
        output = torch.matmul(z, y.unsqueeze(-1)).squeeze(-1)
        return output

    def _conv_impulse_old(self, z, y):
        """Earlier variant of ``_conv_impulse`` using ``fft_size`` taps.

        Not called from ``forward``. Differs from ``_conv_impulse`` in the
        padding applied to ``z`` and in expecting ``fft_size`` (not
        ``fft_size + 1``) samples on the last axis of ``y``.
        """
        z = F.pad(z, (self.window_size // 2 - 1, self.window_size // 2)).squeeze(1)
        # NOTE(review): the shape comments below mention a singleton axis
        # after B, but squeeze(1) above removes it for (B, 1, L) input;
        # confirm before relying on them.
        z = z.unfold(-1, self.window_size, step=self.hop_size)  # (B, 1, T, window_size)

        z = F.pad(z, (self.fft_size // 2 - 1, self.fft_size // 2))
        z = z.unfold(-1, self.fft_size, step=1)  # (B, 1, T, window_size, fft_size)

        # z = matmul(z, y) -> (B, 1, T, window_size) where
        # z: (B, 1, T, window_size, fft_size)
        # y: (B, T, fft_size) -> (B, 1, T, fft_size, 1)
        z = torch.matmul(z, y.unsqueeze(-1)).squeeze(-1)
        return z

    def _ola(self, z):
        """Window the filtered frames and overlap-add them into one signal.

        Args:
            z: Filtered frames with ``window_size`` samples on the last axis
                and frames on the second-to-last axis.

        Returns:
            Tensor reshaped to (B, 1, -1).
        """
        z = z * self.win
        # Split every frame into its left and right half-window.
        l, r = torch.chunk(z, 2, dim=-1)
        # Align the right half of frame t-1 with the left half of frame t.
        # NOTE(review): torch.roll is circular, so the right half of the last
        # frame is added to the left half of the first frame; confirm that
        # this wrap-around is intended.
        z = l + torch.roll(r, 1, dims=-2)  # roll a frame of right side
        z = z.reshape(z.size(0), 1, -1)
        return z

    def _get_ltv_postfilter_fn(self):
        """Build the post-filter selected by ``ltv_postfilter_type``.

        Returns:
            A ``ConvLayers`` module for ``"ddsconv"``, a ``Conv1d`` module for
            ``"conv"``, or ``None`` when the type is ``None``.

        Raises:
            ValueError: If ``ltv_postfilter_type`` is any other value.
        """
        if self.ltv_postfilter_type == "ddsconv":
            fn = ConvLayers(
                in_channels=1,
                conv_channels=64,
                out_channels=1,
                kernel_size=5,
                dilation_size=2,
                n_conv_layers=self.n_ltv_postfilter_layers,
                use_causal=self.use_causal,
                conv_type="ddsconv",
            )
        elif self.ltv_postfilter_type == "conv":
            fn = Conv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=self.ltv_postfilter_kernel_size,
                use_causal=self.use_causal,
            )
        elif self.ltv_postfilter_type is None:
            fn = None
        else:
            raise ValueError(f"Invalid ltv_postfilter_type: {self.ltv_postfilter_type}")
        return fn


class SinusoidsGenerator(nn.Module):
    """Generate a sum of harmonic cosines from a frame-level F0 contour.

    Args:
        hop_size: Upsampling factor from frame rate to sample rate.
        fs: Sampling frequency in Hz.
        harmonic_amp: Scale applied to the summed harmonics.
        n_harmonics: Number of harmonics (1 .. ``n_harmonics``) to generate.
        use_uvmask: If True, zero the output wherever the upsampled F0 is 0.
    """

    def __init__(
        self,
        hop_size,
        fs=24000,
        harmonic_amp=0.1,
        n_harmonics=200,
        use_uvmask=False,
    ):
        super().__init__()
        self.fs = fs
        self.harmonic_amp = harmonic_amp
        self.upsample = nn.Upsample(scale_factor=hop_size, mode="linear")
        self.use_uvmask = use_uvmask
        self.n_harmonics = n_harmonics
        # Harmonic numbers as a column vector, (n_harmonics, 1), so that they
        # broadcast against the time axis of the upsampled F0.
        harmonics = torch.arange(1, self.n_harmonics + 1).unsqueeze(-1)
        self.register_buffer("harmonics", harmonics)

    def forward(self, cf0):
        """Generate the harmonic signal for a continuous F0 contour.

        Args:
            cf0: Frame-level F0; axes 1 and 2 are swapped before upsampling.
                Presumably (B, T, 1) -- TODO confirm against the caller.

        Returns:
            Tensor of shape (B, 1, -1) scaled by ``harmonic_amp``.
        """
        f0 = self.upsample(cf0.transpose(1, 2))
        # Voiced mask: 1 where the upsampled F0 is non-zero, 0 elsewhere.
        # Built directly on f0's device and dtype instead of allocating a
        # CPU tensor, copying it over and scattering ones into it.
        uv = (f0 != 0).to(f0.dtype)
        harmonic = self.generate_sinusoids(f0, uv).reshape(cf0.size(0), 1, -1)
        return self.harmonic_amp * harmonic

    def generate_sinusoids(self, f0, uv):
        """Sum the cosines of all harmonics that pass the anti-aliasing mask.

        Args:
            f0: Sample-level F0; it is broadcast against the
                (n_harmonics, 1) harmonic numbers.
            uv: Mask multiplied into the result when ``use_uvmask`` is True.

        Returns:
            Harmonic sum reduced over dim 1 with ``keepdim=True``.
        """
        mask = self.anti_aliacing_mask(f0 * self.harmonics)
        # Instantaneous phase of each harmonic: cumulative F0 scaled to
        # radians per sample, times the harmonic number.
        rads = f0.cumsum(dim=-1) * 2.0 * math.pi / self.fs * self.harmonics
        harmonic = torch.sum(torch.cos(rads) * mask, dim=1, keepdim=True)
        if self.use_uvmask:
            harmonic = uv * harmonic
        return harmonic

    def anti_aliacing_mask(self, f0_with_harmonics, use_soft_mask=False):
        """Return a mask that suppresses harmonics at or above ``fs / 2``.

        Args:
            f0_with_harmonics: Harmonic frequencies in Hz.
            use_soft_mask: If True, return ``sigmoid(-(f - fs / 2))`` instead
                of a hard 0/1 mask.

        Returns:
            Float tensor with the same shape as ``f0_with_harmonics``.
        """
        if use_soft_mask:
            return torch.sigmoid(-(f0_with_harmonics - self.fs / 2.0))
        else:
            return (f0_with_harmonics < self.fs / 2.0).float()