hsiangyualex's picture
Upload 64 files
f97a499 verified
raw
history blame
9.56 kB
import numpy as np
import torch
import torch.nn as nn
try:
    # torch.fft is importable as a module from PyTorch 1.7.0 onwards.
    import torch.fft

    def dct1_rfft_impl(x):
        """One-sided FFT of a real signal along dim 1; the trailing axis holds (real, imag)."""
        spectrum = torch.fft.rfft(x, dim=1)
        return torch.view_as_real(spectrum)

    def dct_fft_impl(v):
        """Full FFT along dim 1; the trailing axis holds (real, imag)."""
        spectrum = torch.fft.fft(v, dim=1)
        return torch.view_as_real(spectrum)

    def idct_irfft_impl(V):
        """Inverse real FFT along dim 1 of a (real, imag) tensor, with output length V.shape[1]."""
        as_complex = torch.view_as_complex(V)
        return torch.fft.irfft(as_complex, n=V.shape[1], dim=1)

except ImportError:
    # Fallback for PyTorch 1.6.0 and older, using the function-style FFT API.
    def dct1_rfft_impl(x):
        """One-sided FFT of a real signal via the legacy torch.rfft call."""
        return torch.rfft(x, 1)

    def dct_fft_impl(v):
        """Two-sided FFT of a real signal via the legacy torch.rfft call."""
        return torch.rfft(v, 1, onesided=False)

    def idct_irfft_impl(V):
        """Inverse of the two-sided legacy transform via torch.irfft."""
        return torch.irfft(V, 1, onesided=False)
def dct1(x):
    """
    Discrete Cosine Transform, Type I

    The signal is flattened to 2-D, extended with its reversed interior
    (the flipped signal without its first and last samples), and the real
    part of the one-sided FFT of that extension is returned.

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension, same shape as x
    """
    x_shape = x.shape
    # .contiguous() lets .view() accept transposed / sliced inputs as well,
    # matching what dct() and idct() already do.
    x = x.contiguous().view(-1, x_shape[-1])
    # Even-symmetric extension: x followed by its reversed interior samples.
    x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    # Index 0 of the trailing axis is the real part of the spectrum.
    return dct1_rfft_impl(x)[:, :, 0].view(*x_shape)
def idct1(X):
    """
    Inverse of the DCT-I, computed as a rescaled DCT-I

    The scaling is chosen so that idct1(dct1(x)) == x.

    :param X: the input signal
    :return: the inverse DCT-I of the signal over the last dimension
    """
    length = X.shape[-1]
    # A forward DCT-I divided by 2 * (length - 1) undoes dct1.
    scale = 2 * (length - 1)
    return dct1(X) / scale
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    orig_shape = x.shape
    N = orig_shape[-1]
    flat = x.contiguous().view(-1, N)

    # Reorder the samples: even-indexed ones first, then the odd-indexed ones reversed.
    evens = flat[:, ::2]
    odds_reversed = flat[:, 1::2].flip([1])
    v = torch.cat([evens, odds_reversed], dim=1)

    spectrum = dct_fft_impl(v)
    real_part = spectrum[:, :, 0]
    imag_part = spectrum[:, :, 1]

    # Per-coefficient phase angle -pi * k / (2N).
    angle = -torch.arange(N, dtype=flat.dtype, device=flat.device)[None, :] * np.pi / (2 * N)
    V = real_part * torch.cos(angle) - imag_part * torch.sin(angle)

    if norm == 'ortho':
        # The first coefficient is scaled differently from the remaining ones.
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    return 2 * V.view(*orig_shape)
def idct(X, norm=None):
    """
    Inverse of the DCT-II, i.e. a scaled Discrete Cosine Transform, Type III

    The scaling is chosen so that idct(dct(x)) == x.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last dimension
    """
    orig_shape = X.shape
    N = orig_shape[-1]

    # The division creates a fresh tensor, so the in-place scaling below
    # never modifies the caller's X.
    coeffs = X.contiguous().view(-1, orig_shape[-1]) / 2
    if norm == 'ortho':
        coeffs[:, 0] *= np.sqrt(N) * 2
        coeffs[:, 1:] *= np.sqrt(N / 2) * 2

    # Per-coefficient phase angle pi * k / (2N).
    angle = torch.arange(orig_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    cos_w = torch.cos(angle)
    sin_w = torch.sin(angle)

    # Real/imaginary inputs to the rotation: the imaginary part is a zero
    # followed by the negated, reversed coefficients (minus the last one).
    real_in = coeffs
    leading_zero = coeffs[:, :1] * 0
    negated_reversed = -coeffs.flip([1])[:, :-1]
    imag_in = torch.cat([leading_zero, negated_reversed], dim=1)

    real_out = real_in * cos_w - imag_in * sin_w
    imag_out = real_in * sin_w + imag_in * cos_w
    V = torch.cat([real_out.unsqueeze(2), imag_out.unsqueeze(2)], dim=2)

    v = idct_irfft_impl(V)

    # Undo the even/odd reordering used by dct(): the front of v fills the
    # even slots, the reversed tail of v fills the odd slots.
    even_count = N - (N // 2)
    odd_count = N // 2
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :even_count]
    x[:, 1::2] += v.flip([1])[:, :odd_count]
    return x.view(*orig_shape)
def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)

    Applies the 1-D transform along the last axis, then along the
    second-to-last axis, and restores the original axis order.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    along_last = dct(x, norm=norm)
    swapped = along_last.transpose(-1, -2)
    along_second_last = dct(swapped, norm=norm)
    return along_second_last.transpose(-1, -2)
def idct_2d(X, norm=None):
    """
    Inverse of the 2D DCT-II, i.e. a scaled Discrete Cosine Transform, Type III

    The scaling is chosen so that idct_2d(dct_2d(x)) == x.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 2 dimensions
    """
    along_last = idct(X, norm=norm)
    swapped = along_last.transpose(-1, -2)
    along_second_last = idct(swapped, norm=norm)
    return along_second_last.transpose(-1, -2)
def dct_3d(x, norm=None):
    """
    3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)

    Applies the 1-D transform along each of the last three axes in turn,
    swapping axes between passes, and restores the original axis order.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 3 dimensions
    """
    first_pass = dct(x, norm=norm)
    second_pass = dct(first_pass.transpose(-1, -2), norm=norm)
    third_pass = dct(second_pass.transpose(-1, -3), norm=norm)
    # Undo the two axis swaps in reverse order.
    unswapped = third_pass.transpose(-1, -3)
    return unswapped.transpose(-1, -2)
def idct_3d(X, norm=None):
    """
    Inverse of the 3D DCT-II, i.e. a scaled Discrete Cosine Transform, Type III

    The scaling is chosen so that idct_3d(dct_3d(x)) == x.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 3 dimensions
    """
    first_pass = idct(X, norm=norm)
    second_pass = idct(first_pass.transpose(-1, -2), norm=norm)
    third_pass = idct(second_pass.transpose(-1, -3), norm=norm)
    # Undo the two axis swaps in reverse order.
    unswapped = third_pass.transpose(-1, -3)
    return unswapped.transpose(-1, -2)
class LinearDCT(nn.Linear):
    """Implement any DCT as a linear layer; in practice this executes around
    50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
    increase memory usage.

    The weight matrix is built by applying the chosen transform to an identity
    matrix and is frozen (requires_grad is set to False).

    :param in_features: size of expected input (the layer is square)
    :param type: which dct function in this file to use; one of
        'dct1', 'idct1', 'dct' or 'idct'
    :param norm: the normalization passed to dct / idct, None or 'ortho'
        (not used by 'dct1' / 'idct1')
    :param bias: whether the layer has a bias term; if present it starts at zero
    :raises ValueError: if `type` is not one of the supported transform names
    """

    def __init__(self, in_features, type, norm=None, bias=False):
        # These attributes must exist before nn.Linear.__init__ runs, because
        # it calls reset_parameters(), which reads them.
        self.type = type
        self.N = in_features
        self.norm = norm
        super(LinearDCT, self).__init__(in_features, in_features, bias=bias)

    def reset_parameters(self):
        """Fill the weight with the transform matrix and freeze it."""
        # initialise using dct function
        identity = torch.eye(self.N)
        if self.type == 'dct1':
            matrix = dct1(identity)
        elif self.type == 'idct1':
            matrix = idct1(identity)
        elif self.type == 'dct':
            matrix = dct(identity, norm=self.norm)
        elif self.type == 'idct':
            matrix = idct(identity, norm=self.norm)
        else:
            # Previously an unknown name silently left a random, frozen weight.
            raise ValueError(f"Invalid DCT type: {self.type!r}")
        # Row i of `matrix` is the transform of basis vector i, so the
        # transpose is the weight that makes the layer compute the transform.
        self.weight.data = matrix.data.t()
        self.weight.requires_grad = False  # don't learn this!
        if self.bias is not None:
            # This override replaces nn.Linear's own initialisation, so the
            # bias would otherwise keep whatever memory it was allocated with.
            nn.init.zeros_(self.bias)
def apply_linear_2d(x, linear_layer):
    """Can be used with a LinearDCT layer to do a 2D DCT.

    The layer is applied along the last axis, then along the second-to-last
    axis (by swapping the two), and the axis order is restored.

    :param x: the input signal
    :param linear_layer: any PyTorch Linear layer
    :return: result of linear layer applied to last 2 dimensions
    """
    along_last = linear_layer(x)
    swapped = along_last.transpose(-1, -2)
    return linear_layer(swapped).transpose(-1, -2)
def apply_linear_3d(x, linear_layer):
    """Can be used with a LinearDCT layer to do a 3D DCT.

    The layer is applied along each of the last three axes in turn, swapping
    axes between passes, and the original axis order is restored at the end.

    :param x: the input signal
    :param linear_layer: any PyTorch Linear layer
    :return: result of linear layer applied to last 3 dimensions
    """
    first_pass = linear_layer(x)
    second_pass = linear_layer(first_pass.transpose(-1, -2))
    third_pass = linear_layer(second_pass.transpose(-1, -3))
    # Undo the two axis swaps in reverse order.
    unswapped = third_pass.transpose(-1, -3)
    return unswapped.transpose(-1, -2)
class DCTHelper(nn.Module):
    """
    Implement DCT operations and corresponding masking.

    Holds a forward and an inverse LinearDCT layer for square inputs plus a
    boolean-derived float mask buffer of shape (1, 1, side_length, side_length).
    The mask is only built and stored here; none of the methods in this class
    apply it.
    """

    def __init__(self, side_length: int, norm: str = None, cutoff: float = 0.8, data_range: tuple = (-1.0, 1.0)):
        """
        Args:
            side_length: the side length of the image
            norm: the normalization, None or 'ortho'
            cutoff: the cutoff frequency ratio for low-pass filtering; the mask
                radius is side_length * cutoff, measured from index (0, 0)
            data_range: stored unchanged as self.data_range
        """
        super().__init__()
        # Forward `norm` to both layers so the documented argument takes
        # effect; it used to be accepted and then ignored.
        self.dct = LinearDCT(side_length, 'dct', norm=norm)
        self.idct = LinearDCT(side_length, 'idct', norm=norm)
        mask = self.create_circular_mask(side_length, side_length, radius=side_length * cutoff, center=(0, 0))
        # Registered as a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer('mask', torch.from_numpy(mask).float()[None, None, ...])
        self.data_range = data_range

    @staticmethod
    def create_circular_mask(h, w, center=None, radius=None):
        """
        Build a boolean (h, w) numpy mask that is True inside a disc.

        Args:
            h: mask height (number of rows)
            w: mask width (number of columns)
            center: (x, y) centre of the disc; defaults to the middle of the image
            radius: disc radius; defaults to the smallest distance between the
                centre and the image walls

        Returns:
            Boolean array, True where the distance to `center` is <= `radius`.
        """
        if center is None:  # use the middle of the image
            center = (int(w/2), int(h/2))
        if radius is None:  # use the smallest distance between the center and image walls
            radius = min(center[0], center[1], w-center[0], h-center[1])
        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)
        mask = dist_from_center <= radius
        return mask

    def run_dct(self, x):
        """Apply the forward DCT layer over the last two dimensions of x."""
        return apply_linear_2d(x, self.dct)

    def run_idct(self, x):
        """Apply the inverse DCT layer over the last two dimensions of x."""
        return apply_linear_2d(x, self.idct)

    def forward(self, x, mode: str = 'dct'):
        """
        Dispatch to run_dct or run_idct.

        Args:
            x: the input signal
            mode: 'dct' for the forward transform, 'idct' for the inverse

        Raises:
            ValueError: if mode is neither 'dct' nor 'idct'
        """
        if mode == 'dct':
            return self.run_dct(x)
        elif mode == 'idct':
            return self.run_idct(x)
        else:
            raise ValueError(f"Invalid mode: {mode}")
if __name__ == '__main__':
    # Self-check: each LinearDCT layer must agree with the FFT-based function
    # it was built from, to within 1e-3, on random normal input.
    x = torch.Tensor(1000, 4096)
    x.normal_(0, 1)
    for kind, reference in (('dct', dct), ('idct', idct)):
        layer = LinearDCT(4096, kind)
        error = torch.abs(reference(x) - layer(x))
        assert error.max() < 1e-3, (error, error.max())