#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright (c) 2021 Kazuhiro KOBAYASHI <root.4mac@gmail.com>
#
# Distributed under terms of the MIT license.

"""Complex-cepstrum-based LTV filter and sinusoids generator modules."""

import math

import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

from .layer import Conv1d, ConvLayers


class CCepLTVFilter(nn.Module):
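    """Linear time-varying (LTV) filter parameterized by a complex cepstrum.

    A frame-wise complex cepstrum is estimated from conditioning features,
    converted into frame-wise impulse responses, convolved with the
    excitation signal, and overlap-added back into a waveform.

    Example (a minimal sketch; the feature dimension 80 and the batch and
    frame sizes are illustrative assumptions, not values this module fixes):

    >>> ltv = CCepLTVFilter(in_channels=80)
    >>> x = torch.randn(2, 100, 80)       # (B, T, D) conditioning features
    >>> z = torch.randn(2, 1, 100 * 256)  # (B, 1, T * hop_size) excitation
    >>> y = ltv(x, z)                     # (B, 1, T * hop_size)
    """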
    def __init__(
        self,
        in_channels,
        conv_channels=256,
        ccep_size=222,
        kernel_size=3,
        dilation_size=1,
        group_size=8,
        fft_size=1024,
        hop_size=256,
        n_ltv_layers=3,
        n_ltv_postfilter_layers=1,
        use_causal=False,
        conv_type="original",
        feat2linear_fn=None,
        ltv_postfilter_type="conv",
        ltv_postfilter_kernel_size=128,
    ):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.window_size = hop_size * 2
        self.ccep_size = ccep_size
        self.use_causal = use_causal
        self.feat2linear_fn = feat2linear_fn
        self.ltv_postfilter_type = ltv_postfilter_type
        self.ltv_postfilter_kernel_size = ltv_postfilter_kernel_size
        self.n_ltv_postfilter_layers = n_ltv_postfilter_layers

        win_norm = self.window_size // (hop_size * 2)  # valid only for a Hann window
        # periodic must be True so that the overlapped windows sum to 1
        win = torch.hann_window(self.window_size, periodic=True) / win_norm

        self.conv = ConvLayers(
            in_channels=in_channels,
            conv_channels=conv_channels,
            out_channels=ccep_size,
            kernel_size=kernel_size,
            dilation_size=dilation_size,
            group_size=group_size,
            n_conv_layers=n_ltv_layers,
            use_causal=use_causal,
            conv_type=conv_type,
        )
        self.ltv_postfilter_fn = self._get_ltv_postfilter_fn()

        # quefrency-axis normalization applied to the estimated complex cepstrum
        idx = torch.arange(1, ccep_size // 2 + 1).float()
        quef_norm = torch.cat([torch.flip(idx, dims=[-1]), idx], dim=-1)
        self.padding = (self.fft_size - self.ccep_size) // 2
        self.register_buffer("quef_norm", quef_norm)
        self.register_buffer("win", win)

    def forward(self, x, z):
        """Apply the LTV filter to the excitation z conditioned on x.

        Args:
            x: conditioning features of shape (B, T, D)
            z: excitation signal of shape (B, 1, T * hop_size)
        """
        # estimate the complex cepstrum from the conditioning features
        ccep = self.conv(x) / self.quef_norm
        # apply the LTV filter frame by frame and overlap-add
        log_mag = None if self.feat2linear_fn is None else self.feat2linear_fn(x)
        y = self._ccep2impulse(ccep, ref=log_mag)
        z = self._conv_impulse(z, y)
        z = self._ola(z)
        if self.ltv_postfilter_fn is not None:
            z = self.ltv_postfilter_fn(z.transpose(1, 2)).transpose(1, 2)
        return z

    def _apply_ref_mag(self, real, ref):
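        """Add a reference log magnitude to the log-amplitude spectrum,
        mirroring it onto the negative-frequency bins (Hermitian symmetry)."""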
        # TODO(k2kobayashi): consider whether the following masking is needed;
        # it eliminates very small amplitude values (-100).
        # ref = ref * (ref > -100)
        real[..., : self.fft_size // 2 + 1] += ref
        real[..., self.fft_size // 2 + 1 :] += torch.flip(ref[..., 1:-1], dims=[-1])
        return real

    def _ccep2impulse(self, ccep, ref=None):
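        """Convert a frame-wise complex cepstrum into impulse responses of
        length fft_size + 1 via FFT, log-to-linear magnitude conversion,
        and inverse FFT."""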
        # center the cepstrum within the FFT frame
        ccep = F.pad(ccep, (self.padding, self.padding))
        y = torch.fft.fft(ccep, n=self.fft_size, dim=-1)
        # NOTE(k2kobayashi): we assume the intermediate log amplitude is 10 * log10|mag|
        if ref is not None:
            y.real = self._apply_ref_mag(y.real, ref)
        # convert the log magnitude to linear and rebuild the complex spectrum
        mag, phase = torch.pow(10, y.real / 10), y.imag
        real, imag = mag * torch.cos(phase), mag * torch.sin(phase)
        y = torch.fft.ifft(torch.complex(real, imag), n=self.fft_size + 1, dim=-1)
        return y.real

    def _conv_impulse(self, z, y):
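        """Convolve the excitation with the per-frame impulse responses.

        z: (B, 1, T * hop_size) excitation
        y: (B, T, fft_size + 1) impulse responses
        returns: (B, T, window_size) filtered frames
        """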
        # pad the excitation to length T * hop_size + hop_size
        # z = F.pad(z, (self.hop_size // 2, self.hop_size // 2)).squeeze(1)
        z = F.pad(z, (self.hop_size, 0)).squeeze(1)
        z = z.unfold(-1, self.window_size, step=self.hop_size)  # (B, T, window_size)
        z = F.pad(z, (self.fft_size // 2, self.fft_size // 2))
        z = z.unfold(-1, self.fft_size + 1, step=1)  # (B, T, window_size, fft_size + 1)
        # y: (B, T, fft_size + 1) -> (B, T, fft_size + 1, 1)
        # z: (B, T, window_size, fft_size + 1)
        # output: (B, T, window_size)
        output = torch.matmul(z, y.unsqueeze(-1)).squeeze(-1)
        return output

    def _conv_impulse_old(self, z, y):
        z = F.pad(z, (self.window_size // 2 - 1, self.window_size // 2)).squeeze(1)
        z = z.unfold(-1, self.window_size, step=self.hop_size)  # (B, 1, T, window_size)
        z = F.pad(z, (self.fft_size // 2 - 1, self.fft_size // 2))
        z = z.unfold(-1, self.fft_size, step=1)  # (B, 1, T, window_size, fft_size)
        # z = matmul(z, y) -> (B, 1, T, window_size) where
        # z: (B, 1, T, window_size, fft_size)
        # y: (B, T, fft_size) -> (B, 1, T, fft_size, 1)
        z = torch.matmul(z, y.unsqueeze(-1)).squeeze(-1)
        return z

    def _ola(self, z):
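        """Overlap-add the windowed frames into a waveform of shape
        (B, 1, T * hop_size)."""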
        z = z * self.win
        l, r = torch.chunk(z, 2, dim=-1)  # (B, T, window_size // 2) each
        # shift the right halves by one frame so they overlap the next frame
        z = l + torch.roll(r, 1, dims=-2)
        z = z.reshape(z.size(0), 1, -1)
        return z

    def _get_ltv_postfilter_fn(self):
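        """Build the waveform-domain postfilter selected by ltv_postfilter_type."""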
        if self.ltv_postfilter_type == "ddsconv":
            fn = ConvLayers(
                in_channels=1,
                conv_channels=64,
                out_channels=1,
                kernel_size=5,
                dilation_size=2,
                n_conv_layers=self.n_ltv_postfilter_layers,
                use_causal=self.use_causal,
                conv_type="ddsconv",
            )
        elif self.ltv_postfilter_type == "conv":
            fn = Conv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=self.ltv_postfilter_kernel_size,
                use_causal=self.use_causal,
            )
        elif self.ltv_postfilter_type is None:
            fn = None
        else:
            raise ValueError(f"Invalid ltv_postfilter_type: {self.ltv_postfilter_type}")
        return fn


class SinusoidsGenerator(nn.Module):
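    """Generate a harmonic excitation as a sum of sinusoids from a
    continuous F0 contour.

    Example (a minimal sketch; the batch and frame sizes are illustrative
    assumptions, not values this module fixes):

    >>> gen = SinusoidsGenerator(hop_size=256, fs=24000)
    >>> cf0 = torch.full((2, 100, 1), 200.0)  # (B, T, 1) continuous F0 in Hz
    >>> harmonic = gen(cf0)                   # (B, 1, T * hop_size)
    """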
    def __init__(
        self,
        hop_size,
        fs=24000,
        harmonic_amp=0.1,
        n_harmonics=200,
        use_uvmask=False,
    ):
        super().__init__()
        self.fs = fs
        self.harmonic_amp = harmonic_amp
        self.upsample = nn.Upsample(scale_factor=hop_size, mode="linear")
        self.use_uvmask = use_uvmask
        self.n_harmonics = n_harmonics
        harmonics = torch.arange(1, self.n_harmonics + 1).unsqueeze(-1)
        self.register_buffer("harmonics", harmonics)

    def forward(self, cf0):
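        """Generate a harmonic excitation from a continuous F0 contour.

        cf0: (B, T, 1) frame-level continuous F0 in Hz
        returns: (B, 1, T * hop_size) scaled harmonic signal
        """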
        f0 = self.upsample(cf0.transpose(1, 2))
        # voiced/unvoiced mask: 1 where F0 is nonzero
        uv = torch.zeros(f0.size()).to(f0.device)
        nonzero_indices = torch.nonzero(f0, as_tuple=True)
        uv[nonzero_indices] = 1.0
        harmonic = self.generate_sinusoids(f0, uv).reshape(cf0.size(0), 1, -1)
        return self.harmonic_amp * harmonic

    def generate_sinusoids(self, f0, uv):
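        """Sum cosines at integer multiples of F0, masking harmonics above
        the Nyquist frequency."""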
        mask = self.anti_aliasing_mask(f0 * self.harmonics)
        # accumulate F0 into phase (radians) for each harmonic
        rads = f0.cumsum(dim=-1) * 2.0 * math.pi / self.fs * self.harmonics
        harmonic = torch.sum(torch.cos(rads) * mask, dim=1, keepdim=True)
        if self.use_uvmask:
            harmonic = uv * harmonic
        return harmonic

    def anti_aliasing_mask(self, f0_with_harmonics, use_soft_mask=False):
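        """Return a mask that zeros (hard) or attenuates (soft) harmonics
        whose frequency exceeds the Nyquist frequency fs / 2."""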
        if use_soft_mask:
            return torch.sigmoid(-(f0_with_harmonics - self.fs / 2.0))
        else:
            return (f0_with_harmonics < self.fs / 2.0).float()
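

# Example of combining the two modules (a sketch; the wiring below is an
# assumption based on the tensor shapes, and the feature dimension 80 is
# illustrative):
#
#   generator = SinusoidsGenerator(hop_size=256)
#   ltv = CCepLTVFilter(in_channels=80, hop_size=256)
#   harmonic = generator(cf0)        # (B, 1, T * hop_size) excitation
#   waveform = ltv(feats, harmonic)  # (B, 1, T * hop_size)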