Spaces:
Configuration error
Configuration error
File size: 2,027 Bytes
cfdc687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
from modules.base import BaseModule
from modules.interpolation import InterpolationBlock
from modules.layers import Conv1dWithInitialization
class ConvolutionBlock(BaseModule):
    """LeakyReLU activation followed by a dilated 1-D convolution.

    The activation (negative slope 0.2) is applied first, then a
    ``Conv1dWithInitialization`` with kernel size 3, stride 1 and
    ``padding`` equal to ``dilation``.

    Args:
        in_channels: number of channels fed to the convolution.
        out_channels: number of channels produced by the convolution.
        dilation: dilation of the convolution; also used as its padding.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__()
        # Registered in this order: activation first, then convolution.
        self.leaky_relu = torch.nn.LeakyReLU(0.2)
        # For a standard Conv1d with kernel_size=3, padding == dilation leaves
        # the temporal length unchanged. This assumes Conv1dWithInitialization
        # forwards these arguments to torch.nn.Conv1d — TODO confirm.
        conv_kwargs = dict(
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )
        self.convolution = Conv1dWithInitialization(
            in_channels=in_channels,
            out_channels=out_channels,
            **conv_kwargs,
        )

    def forward(self, x):
        """Return ``convolution(leaky_relu(x))``."""
        activated = self.leaky_relu(x)
        return self.convolution(activated)
class DownsamplingBlock(BaseModule):
    """Residual block that downsamples its input with two parallel branches.

    Main branch: an ``InterpolationBlock`` (``downsample=True``, linear mode,
    ``scale_factor=factor``), followed by one ``ConvolutionBlock`` per entry
    of ``dilations``. The first conv maps ``in_channels -> out_channels``;
    every later conv maps ``out_channels -> out_channels``.

    Residual branch: a 1x1 ``Conv1dWithInitialization``
    (``in_channels -> out_channels``), followed by the same kind of
    ``InterpolationBlock``.

    ``forward`` returns the element-wise sum of the two branches.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        factor: passed as ``scale_factor`` to both ``InterpolationBlock``s
            (presumably the downsampling factor — confirm in
            ``modules.interpolation``).
        dilations: sized sequence of dilations, one ``ConvolutionBlock`` each.

    NOTE(review): if ``dilations`` is empty, the main branch keeps
    ``in_channels`` channels. The sum in ``forward`` then only lines up when
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, factor, dilations):
        super().__init__()

        def downsampler():
            # A fresh InterpolationBlock configured to downsample by `factor`.
            return InterpolationBlock(
                scale_factor=factor,
                mode='linear',
                align_corners=False,
                downsample=True,
            )

        # Modules are built in the same order as before: main interpolation,
        # main convs, shortcut conv, residual interpolation. This keeps any
        # RNG-dependent initialization order unchanged.
        main_resample = downsampler()

        # Channel plan: (in -> out) for the first conv, (out -> out) after.
        # A negative repeat count yields [], matching the original when
        # `dilations` is empty.
        channel_pairs = [(in_channels, out_channels)]
        channel_pairs += [(out_channels, out_channels)] * (len(dilations) - 1)
        # zip truncates to len(dilations), so no convs when it is empty.
        conv_layers = [
            ConvolutionBlock(c_in, c_out, dilation)
            for (c_in, c_out), dilation in zip(channel_pairs, dilations)
        ]
        # Index 0 is the interpolation and indices 1.. are the convs.
        # These indices form the state_dict keys.
        self.main_branch = torch.nn.Sequential(main_resample, *conv_layers)

        shortcut_conv = Conv1dWithInitialization(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
        )
        # Index 0 is the 1x1 conv and index 1 is the interpolation.
        self.residual_branch = torch.nn.Sequential(shortcut_conv, downsampler())

    def forward(self, x):
        """Return ``main_branch(x) + residual_branch(x)``.

        Presumably ``x`` is (batch, in_channels, time) — TODO confirm.
        """
        return self.main_branch(x) + self.residual_branch(x)
|