import torch
from torch import nn


class Gauss_Pyramid_Conv(nn.Module):
    """Build a pyramid of Gaussian-blurred images.

    Code borrowed from: https://github.com/csjliang/LPTN

    Each level is produced by blurring the current image with a fixed 5x5
    binomial kernel; the blurred image is stored and then subsampled by a
    factor of two to form the input of the next level.
    """

    def __init__(self, num_high=3):
        """Create the module.

        Args:
            num_high: Number of blur-and-downsample steps performed by
                ``forward``. The returned pyramid has ``num_high + 1`` entries.
        """
        super().__init__()
        self.num_high = num_high
        self.kernel = self.gauss_kernel()

    def gauss_kernel(self, device=None, channels=3):
        """Return the normalized 5x5 kernel, one copy per channel.

        Args:
            device: Device to place the kernel on. When ``None``, CUDA is used
                if it is available and the CPU otherwise, so the module can
                also be constructed on machines without a GPU.
            channels: Number of per-channel copies of the kernel.

        Returns:
            A tensor of shape ``(channels, 1, 5, 5)`` whose 5x5 slices each
            sum to one.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        kernel = torch.tensor([[1., 4., 6., 4., 1.],
                               [4., 16., 24., 16., 4.],
                               [6., 24., 36., 24., 6.],
                               [4., 16., 24., 16., 4.],
                               [1., 4., 6., 4., 1.]])
        # The weights sum to 256; normalizing keeps the image brightness.
        kernel /= 256.
        kernel = kernel.repeat(channels, 1, 1, 1)
        return kernel.to(device)

    @staticmethod
    def _match_kernel(kernel, img):
        """Return ``kernel`` on ``img``'s device (and floating dtype).

        ``Tensor.to`` returns the tensor itself when nothing has to change, so
        this is free when the kernel already matches the input.
        """
        # Never cast the fractional weights to an integer dtype: they would
        # all truncate to zero.
        dtype = img.dtype if img.is_floating_point() else kernel.dtype
        return kernel.to(device=img.device, dtype=dtype)

    def downsample(self, x):
        """Keep every second row and column of the last two dimensions."""
        return x[:, :, ::2, ::2]

    def conv_gauss(self, img, kernel):
        """Blur ``img`` with ``kernel`` while preserving its spatial size.

        The input is reflect-padded by 2 pixels on every side so that the 5x5
        convolution returns the original height and width. The convolution is
        grouped with ``groups=img.shape[1]``.

        Args:
            img: Input tensor indexed as (batch, channel, height, width).
            kernel: Convolution weight, e.g. the result of ``gauss_kernel``.

        Returns:
            The blurred tensor.
        """
        kernel = self._match_kernel(kernel, img)
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
        out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
        return out

    def forward(self, img):
        """Compute the pyramid for ``img``.

        Args:
            img: Input tensor indexed as (batch, channel, height, width).

        Returns:
            A list with ``num_high + 1`` tensors: the blurred image of every
            level (each at that level's resolution), followed by the final
            downsampled image.
        """
        # Align the kernel once instead of once per level.
        kernel = self._match_kernel(self.kernel, img)
        current = img
        pyr = []
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, kernel)
            pyr.append(filtered)
            current = self.downsample(filtered)
        pyr.append(current)
        return pyr