Spaces:
Sleeping
Sleeping
File size: 1,372 Bytes
207ef6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
from torch import nn
class Gauss_Pyramid_Conv(nn.Module):
    """Gaussian image pyramid built with a fixed 5x5 binomial blur.

    Code borrowed from: https://github.com/csjliang/LPTN

    ``forward`` returns a list of ``num_high + 1`` tensors:

    * For each of the first ``num_high`` levels, the blurred version of that level.
    * Each next level is the stride-2 subsampling of the previous blurred level.
    * The last entry is the final subsampled image, which is not blurred again.
    """

    def __init__(self, num_high=3):
        super().__init__()
        self.num_high = num_high
        # Non-persistent buffer: ``module.to(...)`` / ``.cuda()`` / ``.half()``
        # now move/cast it together with the module, while it stays out of the
        # state_dict (the previous plain attribute was not saved either, so old
        # checkpoints keep loading).
        self.register_buffer('kernel', self.gauss_kernel(), persistent=False)

    def gauss_kernel(self, device=None, channels=3):
        """Return the normalized 5x5 binomial kernel, one copy per channel.

        Args:
            device: Device to place the kernel on. ``None`` keeps the default
                tensor device. This previously defaulted to CUDA, which made
                construction fail on machines without a GPU.
            channels: Number of copies stacked along dim 0. This is for use as
                a depthwise (grouped) convolution weight.

        Returns:
            Tensor of shape ``(channels, 1, 5, 5)``. Each 5x5 slice sums to 1.
        """
        # Outer product of the binomial row [1, 4, 6, 4, 1] with itself;
        # dividing by 256 (= 16 * 16) normalizes the weights to sum to 1.
        kernel = torch.tensor([[1., 4., 6., 4., 1.],
                               [4., 16., 24., 16., 4.],
                               [6., 24., 36., 24., 6.],
                               [4., 16., 24., 16., 4.],
                               [1., 4., 6., 4., 1.]])
        kernel /= 256.
        # repeat() on the 2-D (5, 5) tensor with four factors prepends the
        # missing dims, giving (channels, 1, 5, 5).
        kernel = kernel.repeat(channels, 1, 1, 1)
        if device is not None:
            kernel = kernel.to(device)
        return kernel

    def downsample(self, x):
        """Keep every other row and column of the last two dims of ``x``.

        Odd sizes round up, e.g. 5 -> 3.
        """
        return x[:, :, ::2, ::2]

    def conv_gauss(self, img, kernel):
        """Blur ``img`` with ``kernel`` as a depthwise convolution.

        Args:
            img: Input tensor. Its dim 1 is used as the group count.
            kernel: Weight with ``kernel.shape[0] == img.shape[1]``.

        Returns:
            Tensor with the same spatial size as ``img``. The 2-pixel reflect
            padding offsets the 5x5 kernel.

        Note:
            Reflect padding requires both spatial dims of ``img`` to be larger
            than 2.
        """
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
        out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
        return out

    def _kernel_for(self, img):
        """Return ``self.kernel`` matched to ``img``'s channel count, device and dtype."""
        kernel = self.kernel
        channels = img.shape[1]
        if kernel.shape[0] != channels:
            # Every per-channel slice is identical, so re-tile the first one.
            kernel = kernel[:1].repeat(channels, 1, 1, 1)
        # Only follow floating dtypes; casting the weights to an integer dtype
        # would truncate them to zero.
        dtype = img.dtype if img.is_floating_point() else kernel.dtype
        return kernel.to(device=img.device, dtype=dtype)

    def forward(self, img):
        """Build the pyramid for ``img``.

        Returns:
            List of ``num_high + 1`` tensors, ordered from finest to coarsest.
        """
        kernel = self._kernel_for(img)
        current = img
        pyr = []
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, kernel)
            pyr.append(filtered)
            current = self.downsample(filtered)
        pyr.append(current)
        return pyr
|