Spaces:
Configuration error
Configuration error
File size: 7,568 Bytes
cfdc687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright (c) 2021 Kazuhiro KOBAYASHI <root.4mac@gmail.com>
#
# Distributed under terms of the MIT license.
"""
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from .layer import Conv1d, ConvLayers
class CCepLTVFilter(nn.Module):
    """Linear time-varying (LTV) filter parameterised by a complex cepstrum.

    ``self.conv`` (a ``ConvLayers`` stack) predicts ``ccep_size`` cepstral
    coefficients per frame from the conditioning features. Each frame's
    cepstrum is converted into an impulse response of ``fft_size + 1`` taps.
    The excitation is filtered frame by frame over windows of
    ``window_size = 2 * hop_size`` samples, and the filtered frames are
    overlap-added with a periodic Hann window. An optional post-filter
    (selected by ``ltv_postfilter_type``) is then applied to the waveform.

    Args:
        in_channels: number of conditioning feature channels ``D``.
        conv_channels: hidden channels of the cepstrum predictor.
        ccep_size: number of predicted cepstral coefficients; must be even.
        kernel_size, dilation_size, group_size, n_ltv_layers, conv_type:
            forwarded to ``ConvLayers`` for the cepstrum predictor.
        fft_size: FFT length used to turn the cepstrum into a spectrum.
        hop_size: frame shift in samples.
        n_ltv_postfilter_layers: layer count of the "ddsconv" post-filter.
        use_causal: forwarded to the predictor and to the post-filter.
        feat2linear_fn: optional callable applied to ``x``. Its output is
            used as a one-sided log-magnitude reference that is added to the
            predicted log-magnitude (see ``_apply_ref_mag``).
        ltv_postfilter_type: "conv", "ddsconv" or None.
        ltv_postfilter_kernel_size: kernel size of the "conv" post-filter.

    Raises:
        ValueError: if ``ccep_size`` is odd, or if ``ltv_postfilter_type`` is
            not one of the supported values.
    """

    def __init__(
        self,
        in_channels,
        conv_channels=256,
        ccep_size=222,
        kernel_size=3,
        dilation_size=1,
        group_size=8,
        fft_size=1024,
        hop_size=256,
        n_ltv_layers=3,
        n_ltv_postfilter_layers=1,
        use_causal=False,
        conv_type="original",
        feat2linear_fn=None,
        ltv_postfilter_type="conv",
        ltv_postfilter_kernel_size=128,
    ):
        super().__init__()
        # quef_norm (built below) has 2 * (ccep_size // 2) entries. It must
        # broadcast against the ccep_size coefficients predicted by self.conv,
        # so an odd ccep_size can never work; fail early with a clear message.
        if ccep_size % 2 != 0:
            raise ValueError(f"ccep_size must be even, got {ccep_size}")
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.window_size = hop_size * 2
        self.ccep_size = ccep_size
        self.use_causal = use_causal
        self.feat2linear_fn = feat2linear_fn
        self.ltv_postfilter_type = ltv_postfilter_type
        self.ltv_postfilter_kernel_size = ltv_postfilter_kernel_size
        self.n_ltv_postfilter_layers = n_ltv_postfilter_layers
        win_norm = self.window_size // (hop_size * 2)  # only for hanning window
        # periodic must be True to become OLA 1
        win = torch.hann_window(self.window_size, periodic=True) / win_norm
        self.conv = ConvLayers(
            in_channels=in_channels,
            conv_channels=conv_channels,
            out_channels=ccep_size,
            kernel_size=kernel_size,
            dilation_size=dilation_size,
            group_size=group_size,
            n_conv_layers=n_ltv_layers,
            use_causal=use_causal,
            conv_type=conv_type,
        )
        self.ltv_postfilter_fn = self._get_ltv_postfilter_fn()
        # Divisor applied to the predicted cepstrum: [n, ..., 2, 1, 1, 2, ..., n]
        # with n = ccep_size // 2.
        idx = torch.arange(1, ccep_size // 2 + 1).float()
        quef_norm = torch.cat([torch.flip(idx, dims=[-1]), idx], dim=-1)
        # Symmetric zero padding that places the ccep_size coefficients in the
        # middle of an fft_size-long buffer.
        self.padding = (self.fft_size - self.ccep_size) // 2
        self.register_buffer("quef_norm", quef_norm)
        self.register_buffer("win", win)

    def forward(self, x, z):
        """Filter the excitation ``z`` with the LTV filter predicted from ``x``.

        x: B, T, D
        z: B, 1, T * hop_size
        Returns a (B, 1, T * hop_size) waveform, passed through the
        post-filter when one is configured.
        """
        # inference complex cepstrum
        ccep = self.conv(x) / self.quef_norm
        # apply LTV filter and overlap
        # assumes feat2linear_fn returns a (..., fft_size // 2 + 1) log-magnitude
        # in 10*log10 units -- TODO confirm against the callable passed in.
        log_mag = None if self.feat2linear_fn is None else self.feat2linear_fn(x)
        y = self._ccep2impulse(ccep, ref=log_mag)
        z = self._conv_impulse(z, y)
        z = self._ola(z)
        if self.ltv_postfilter_fn is not None:
            # The post-filter is applied to a (B, T * hop_size, 1) view.
            z = self.ltv_postfilter_fn(z.transpose(1, 2)).transpose(1, 2)
        return z

    def _apply_ref_mag(self, real, ref):
        """Add a one-sided reference spectrum to a full-length spectrum.

        real: (..., fft_size) full spectrum, e.g. the log-magnitude part.
        ref: (..., fft_size // 2 + 1) one-sided values, from DC up to
            Nyquist when fft_size is even.
        Returns ``real`` plus ``ref`` mirrored with Hermitian symmetry, i.e.
        bin ``k`` and bin ``fft_size - k`` both receive ``ref[k]``. ``real``
        itself is not modified.
        """
        # TODO(k2kobayashi): it requires to consider following line.
        # this mask eliminates very small amplitude values (-100).
        # ref = ref * (ref > -100)
        # Only bins 1 .. ceil(fft_size / 2) - 1 have a negative-frequency
        # mirror. DC (and Nyquist for an even fft_size) must be added exactly
        # once, so they are excluded from the mirrored part.
        mirrored = torch.flip(ref[..., 1 : (self.fft_size + 1) // 2], dims=[-1])
        return real + torch.cat([ref, mirrored], dim=-1)

    def _ccep2impulse(self, ccep, ref=None):
        """Convert framewise complex cepstra into real impulse responses.

        ccep: (..., ccep_size) cepstral coefficients.
        ref: optional one-sided log-magnitude added before exponentiation.
        Returns a real tensor of shape (..., fft_size + 1).
        """
        ccep = F.pad(ccep, (self.padding, self.padding))
        spec = torch.fft.fft(ccep, n=self.fft_size, dim=-1)
        # NOTE(k2kobayashi): we assume intermediate log amplitude as 10log10|mag|
        log_mag = spec.real
        if ref is not None:
            log_mag = self._apply_ref_mag(log_mag, ref)
        # logarithmic to linear
        mag, phase = torch.pow(10, log_mag / 10), spec.imag
        real, imag = mag * torch.cos(phase), mag * torch.sin(phase)
        y = torch.fft.ifft(torch.complex(real, imag), n=self.fft_size + 1, dim=-1)
        return y.real

    def _conv_impulse(self, z, y):
        """Filter each analysis frame of ``z`` with that frame's impulse response.

        z: (B, 1, T * hop_size) excitation.
        y: (B, T, fft_size + 1) impulse responses.
        Returns (B, T, window_size): for each frame, the unwindowed filter
        output at the window_size positions of that frame.
        """
        # (B, T * hop_size + hop_size)
        z = F.pad(z, (self.hop_size, 0)).squeeze(1)
        z = z.unfold(-1, self.window_size, step=self.hop_size)  # (B, T, window_size)
        z = F.pad(z, (self.fft_size // 2, self.fft_size // 2))
        z = z.unfold(-1, self.fft_size + 1, step=1)  # (B, T, window_size, fft_size + 1)
        # This is a cross-correlation (y is not time-reversed) centred on tap
        # fft_size // 2; the input outside each frame is treated as zero.
        # y: (B, T, fft_size + 1) -> (B, T, fft_size + 1, 1)
        # z: (B, T, window_size, fft_size + 1)
        # output: (B, T, window_size)
        output = torch.matmul(z, y.unsqueeze(-1)).squeeze(-1)
        return output

    def _conv_impulse_old(self, z, y):
        """Legacy variant of ``_conv_impulse``, kept for reference.

        This method is not called anywhere in this class. It expects impulse
        responses of length ``fft_size``, whereas ``_ccep2impulse`` now
        produces ``fft_size + 1`` taps.
        """
        z = F.pad(z, (self.window_size // 2 - 1, self.window_size // 2)).squeeze(1)
        z = z.unfold(-1, self.window_size, step=self.hop_size)  # (B, T, window_size)
        z = F.pad(z, (self.fft_size // 2 - 1, self.fft_size // 2))
        z = z.unfold(-1, self.fft_size, step=1)  # (B, T, window_size, fft_size)
        # z = matmul(z, y) -> (B, T, window_size) where
        # z: (B, T, window_size, fft_size)
        # y: (B, T, fft_size) -> (B, T, fft_size, 1)
        z = torch.matmul(z, y.unsqueeze(-1)).squeeze(-1)
        return z

    def _ola(self, z):
        """Window the frames and overlap-add them with 50 % overlap.

        z: (B, T, window_size) framewise filter outputs.
        Returns (B, 1, T * hop_size). Output hop ``t`` is the left half of
        frame ``t`` plus the right half of frame ``t - 1``. The right half of
        the last frame falls past the end of the output and is dropped.
        """
        z = z * self.win
        left, right = torch.chunk(z, 2, dim=-1)  # each (B, T, window_size // 2)
        # Delay the right halves by one frame, inserting a zero frame at the
        # start. torch.roll would instead wrap the last frame's tail around
        # onto the first output samples.
        right = F.pad(right, (0, 0, 1, 0))[..., :-1, :]
        z = left + right
        return z.reshape(z.size(0), 1, -1)

    def _get_ltv_postfilter_fn(self):
        """Build the optional waveform post-filter selected by ltv_postfilter_type.

        Returns a ``ConvLayers`` for "ddsconv", a ``Conv1d`` for "conv", or
        None when ltv_postfilter_type is None.

        Raises:
            ValueError: for any other ltv_postfilter_type.
        """
        if self.ltv_postfilter_type == "ddsconv":
            fn = ConvLayers(
                in_channels=1,
                conv_channels=64,
                out_channels=1,
                kernel_size=5,
                dilation_size=2,
                n_conv_layers=self.n_ltv_postfilter_layers,
                use_causal=self.use_causal,
                conv_type="ddsconv",
            )
        elif self.ltv_postfilter_type == "conv":
            fn = Conv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=self.ltv_postfilter_kernel_size,
                use_causal=self.use_causal,
            )
        elif self.ltv_postfilter_type is None:
            fn = None
        else:
            raise ValueError(f"Invalid ltv_postfilter_type: {self.ltv_postfilter_type}")
        return fn
class SinusoidsGenerator(nn.Module):
    """Generate a sum-of-harmonics excitation from a frame-level F0 contour.

    Args:
        hop_size: upsampling factor from frames to samples (linear interpolation).
        fs: sampling rate in Hz.
        harmonic_amp: gain applied to the summed harmonics.
        n_harmonics: harmonics 1 .. n_harmonics of F0 are summed.
        use_uvmask: if True, zero the output wherever the upsampled F0 is 0.
    """

    def __init__(
        self,
        hop_size,
        fs=24000,
        harmonic_amp=0.1,
        n_harmonics=200,
        use_uvmask=False,
    ):
        super().__init__()
        self.fs = fs
        self.harmonic_amp = harmonic_amp
        self.upsample = nn.Upsample(scale_factor=hop_size, mode="linear")
        self.use_uvmask = use_uvmask
        self.n_harmonics = n_harmonics
        # (n_harmonics, 1) harmonic numbers; broadcasts against (B, 1, L) F0.
        harmonics = torch.arange(1, self.n_harmonics + 1).unsqueeze(-1)
        self.register_buffer("harmonics", harmonics)

    def forward(self, cf0):
        """Synthesise the harmonic excitation for a frame-level F0.

        cf0: 3-D frame-level F0 in Hz, presumably (B, T, 1) -- TODO confirm.
            It is transposed to (B, 1, T) and linearly upsampled by hop_size.
        Returns a (B, 1, T * hop_size) tensor scaled by harmonic_amp.
        """
        f0 = self.upsample(cf0.transpose(1, 2))
        # Voiced/unvoiced flag: 1 where F0 is non-zero, 0 elsewhere. Building it
        # with an elementwise comparison keeps it on f0's device. It avoids a
        # host allocation, a device copy, and the host sync of torch.nonzero.
        uv = (f0 != 0).to(f0.dtype)
        harmonic = self.generate_sinusoids(f0, uv).reshape(cf0.size(0), 1, -1)
        return self.harmonic_amp * harmonic

    def generate_sinusoids(self, f0, uv):
        """Sum cosines at integer multiples of the sample-level F0.

        f0: (B, 1, L) sample-level F0 in Hz.
        uv: voiced/unvoiced mask with the same shape as f0. It is only applied
            when use_uvmask is True.
        Returns (B, 1, L). Harmonics at or above fs / 2 are masked out.
        """
        mask = self.anti_aliacing_mask(f0 * self.harmonics)
        # Instantaneous phase = 2*pi/fs * running sum of F0, per harmonic.
        # NOTE(review): the running sum is accumulated in f0's dtype. For long
        # float32 signals this loses phase precision in high harmonics;
        # consider wrapping the phase if that matters.
        rads = f0.cumsum(dim=-1) * 2.0 * math.pi / self.fs * self.harmonics
        harmonic = torch.sum(torch.cos(rads) * mask, dim=1, keepdim=True)
        if self.use_uvmask:
            harmonic = uv * harmonic
        return harmonic

    def anti_aliacing_mask(self, f0_with_harmonics, use_soft_mask=False):
        """Mask of harmonics below Nyquist (fs / 2).

        Returns a float tensor with the same shape as ``f0_with_harmonics``. By
        default it is a hard 0/1 mask (strictly below fs / 2). With
        ``use_soft_mask`` it is a sigmoid of the distance below fs / 2 in Hz.
        """
        if use_soft_mask:
            return torch.sigmoid(-(f0_with_harmonics - self.fs / 2.0))
        else:
            return (f0_with_harmonics < self.fs / 2.0).float()
|