File size: 7,268 Bytes
d72b2c3 8e60374 d72b2c3 8e60374 d72b2c3 8e60374 d72b2c3 8e60374 d72b2c3 8e60374 d72b2c3 8e60374 d72b2c3 d8e2a3d d72b2c3 8e60374 d72b2c3 d8e2a3d d72b2c3 d8e2a3d d72b2c3 d8e2a3d d72b2c3 d8e2a3d d72b2c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
# Normalization names accepted by `apply_parametrization_norm`.
CONV_NORMALIZATIONS = frozenset((
    'none',
    'weight_norm',
    'spectral_norm',
    'time_group_norm',
))
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
    """Wrap `module` with the weight reparametrization selected by `norm`.

    `norm` must be a member of `CONV_NORMALIZATIONS`. For 'weight_norm' and
    'spectral_norm' the result of the matching `torch.nn.utils` helper is
    returned; for every other accepted name `module` is returned untouched.
    """
    assert norm in CONV_NORMALIZATIONS
    reparametrizations = {
        'weight_norm': weight_norm,
        'spectral_norm': spectral_norm,
    }
    wrap = reparametrizations.get(norm)
    if wrap is None:
        # Membership in CONV_NORMALIZATIONS was already checked, so any other
        # choice needs no reparametrization.
        return module
    return wrap(module)
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many samples must be appended so the last frame is complete.

    The (possibly fractional) number of frames produced from the last
    dimension of `x` is rounded up, the length that yields exactly that many
    frames is computed, and the difference with the current length is
    returned.
    """
    current_length = x.shape[-1]
    fractional_frames = (current_length - kernel_size + padding_total) / stride + 1
    whole_frames = math.ceil(fractional_frames)
    # Length for which `whole_frames` windows fit exactly.
    target_length = (whole_frames - 1) * stride + (kernel_size - padding_total)
    return target_length - current_length
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Pad the last dimension of `x`, tolerating reflect padding on short inputs.

    This is a thin wrapper around `F.pad`. In 'reflect' mode, when the input
    is not longer than the largest requested padding, zeros are first appended
    on the right so the reflection is possible; the same number of trailing
    samples is removed from the result afterwards.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)

    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)

    # Number of zeros needed to make the input strictly longer than the
    # largest padding (0 when it already is).
    extra_pad = max(0, max(padding_left, padding_right) - length + 1)
    if extra_pad:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    keep = padded.shape[-1] - extra_pad
    return padded[..., :keep]
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip `paddings = (left, right)` samples from the last dimension of `x`.

    A zero right padding is handled correctly because the end index is
    computed explicitly rather than using a negative slice bound.
    Only meant for 1d padding.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    total_length = x.shape[-1]
    assert (padding_left + padding_right) <= total_length
    stop = total_length - padding_right
    return x[..., padding_left:stop]
class NormConv1d(nn.Module):
    """`nn.Conv1d` wrapped with an optional weight reparametrization.

    Positional and extra keyword arguments are forwarded to `nn.Conv1d`.
    `norm` selects the reparametrization applied by
    `apply_parametrization_norm`. `causal` and `norm_kwargs` are accepted
    but not used by this class.
    """

    def __init__(self, *args, causal=False, norm='none', norm_kwargs={}, **kwargs):
        super().__init__()
        conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(conv, norm)

    def forward(self, x):
        """Run the wrapped convolution on `x`."""
        out = self.conv(x)
        return out
class NormConvTranspose1d(nn.Module):
    """`nn.ConvTranspose1d` wrapped with an optional weight reparametrization.

    Positional and extra keyword arguments are forwarded to
    `nn.ConvTranspose1d`. `norm` selects the reparametrization applied by
    `apply_parametrization_norm`. `causal` and `norm_kwargs` are accepted
    but not used by this class.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        transposed = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(transposed, norm)

    def forward(self, x):
        """Run the wrapped transposed convolution on `x`."""
        out = self.convtr(x)
        return out
class StreamableConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.

    The input is padded before the convolution by a fixed amount of
    `effective_kernel_size - stride`, plus whatever
    `get_extra_padding_for_conv1d` reports so that the last frame is complete.
    In causal mode the fixed padding goes entirely on the left; otherwise it
    is split between both sides, with the left side receiving the extra
    sample when the total is odd. The extra padding is always added on the
    right.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups,
            bias: forwarded to the underlying `nn.Conv1d`.
        causal: if True, use left-only (causal) fixed padding.
        norm: normalization name forwarded to `NormConv1d`.
        norm_kwargs: extra normalization options forwarded to `NormConv1d`;
            `None` is treated as an empty dict.
        pad_mode: padding mode handed to `pad1d`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True,
                 causal=False,
                 norm='none',
                 norm_kwargs=None,
                 pad_mode='reflect'):
        super().__init__()
        # Avoid sharing one mutable default dict between instances.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        """Pad `x` along its last dimension and apply the convolution."""
        # Unpacking rejects any input that does not have exactly 3 dimensions.
        B, C, T = x.shape
        conv = self.conv.conv
        kernel_size = conv.kernel_size[0]
        stride = conv.stride[0]
        dilation = conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal: all the fixed padding precedes the signal.
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)
class StreamableConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.

    After the transposed convolution, a fixed amount of
    `kernel_size - stride` samples is trimmed from the output. In causal mode
    `trim_right_ratio` decides which share of that amount is removed from the
    right (the rest is removed from the left); otherwise the trimming is split
    between both sides, with the left side losing the extra sample when the
    total is odd.

    Args:
        in_channels, out_channels, kernel_size, stride: forwarded to the
            underlying `nn.ConvTranspose1d`.
        causal: if True, trim according to `trim_right_ratio`.
        norm: normalization name forwarded to `NormConvTranspose1d`.
        trim_right_ratio: fraction in [0, 1] of the fixed padding trimmed on
            the right in causal mode; must be 1.0 when `causal` is False.
        norm_kwargs: extra normalization options forwarded to
            `NormConvTranspose1d`; `None` is treated as an empty dict.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None):
        super().__init__()
        # Avoid sharing one mutable default dict between instances.
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        """Apply the transposed convolution and trim the fixed padding."""
        convtr = self.convtr.convtr
        kernel_size = convtr.kernel_size[0]
        stride = convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio:
            # with trim_right_ratio = 1.0 everything is trimmed from the right.
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
        y = unpad1d(y, (padding_left, padding_right))
        return y
|