|
from __future__ import absolute_import |
|
from builtins import zip |
|
import numpy.fft as ffto |
|
from .numpy_wrapper import wrap_namespace |
|
from .numpy_vjps import match_complex |
|
from . import numpy_wrapper as anp |
|
from autograd.extend import primitive, defvjp, vspace |
|
|
|
# Populate this module's namespace from numpy.fft via wrap_namespace; the
# transform names registered below (fft, rfft, fftshift, ...) come from here.
wrap_namespace(vars(ffto), globals())
|
|
|
|
|
|
|
def fft_grad(get_args, fft_fun, ans, x, *args, **kwargs):
    """Build the VJP for a complex FFT primitive (fft, ifft, fft2, ...).

    The DFT matrix is symmetric, so the transposed transform is the same
    transform: the VJP re-applies ``fft_fun`` with the original arguments to
    the incoming gradient, then truncates/zero-pads the result back to the
    shape of ``x``.

    Args:
      get_args: parser returning ``(axes, s, norm)`` from the call arguments.
      fft_fun: the transform whose VJP is being built.
      ans: the forward output (unused here).
      x: the forward input.
      *args, **kwargs: the remaining arguments of the forward call.

    Raises:
      NotImplementedError: if the parsed ``axes`` contain a repeated entry.
    """
    axes = get_args(x, *args, **kwargs)[0]
    check_no_repeated_axes(axes)
    target_shape = vspace(x).shape

    def vjp(g):
        transformed = fft_fun(g, *args, **kwargs)
        return match_complex(x, truncate_pad(transformed, target_shape))

    return vjp
|
|
|
# Complex FFT primitives: each VJP re-applies the same transform (see fft_grad).
# Every family must use the argument parser that matches its NumPy signature:
#   fft/ifft   -> (a, n, axis, norm)   -> get_fft_args
#   fft2/ifft2 -> (a, s, axes, norm)   -> get_fft2_args
#   fftn/ifftn -> (a, s, axes, norm)   -> get_fftn_args
# Using the 1-D parser for the 2-D/N-D transforms mis-parses `axes`, so
# repeated axes (whose gradient is not implemented) would go undetected.
defvjp(fft, lambda *args, **kwargs:
       fft_grad(get_fft_args, fft, *args, **kwargs))
defvjp(ifft, lambda *args, **kwargs:
       fft_grad(get_fft_args, ifft, *args, **kwargs))

defvjp(fft2, lambda *args, **kwargs:
       fft_grad(get_fft2_args, fft2, *args, **kwargs))
defvjp(ifft2, lambda *args, **kwargs:
       fft_grad(get_fft2_args, ifft2, *args, **kwargs))

defvjp(fftn, lambda *args, **kwargs:
       fft_grad(get_fftn_args, fftn, *args, **kwargs))
defvjp(ifftn, lambda *args, **kwargs:
       fft_grad(get_fftn_args, ifftn, *args, **kwargs))
|
|
|
def rfft_grad(get_args, irfft_fun, ans, x, *args, **kwargs):
    """Build the VJP for a real-input forward FFT (rfft, rfft2, rfftn).

    The incoming gradient is divided by the compression/normalization factors
    from ``make_rfft_factors``, conjugated, and sent through the matching
    inverse transform ``irfft_fun``. The result is truncated/zero-padded back
    to the shape of ``x``.

    Args:
      get_args: parser returning ``(axes, s, norm)`` from the call arguments.
      irfft_fun: inverse real transform used to apply the adjoint.
      ans: the forward output; its shape is the compressed spectrum shape.
      x: the forward (real) input.
      *args, **kwargs: the remaining arguments of the forward call.

    Raises:
      NotImplementedError: for repeated axes or an odd-length last axis.
    """
    axes, s, norm = get_args(x, *args, **kwargs)
    in_shape = vspace(x).shape
    out_shape = vspace(ans).shape
    check_no_repeated_axes(axes)

    # Full (uncompressed) transform lengths along `axes`.
    full_shape = s if s is not None else [in_shape[ax] for ax in axes]
    check_even_shape(full_shape)

    # rfft stores only n // 2 + 1 bins along the last transformed axis.
    packed_shape = list(full_shape)
    packed_shape[-1] = packed_shape[-1] // 2 + 1
    fac = make_rfft_factors(axes, out_shape, packed_shape, full_shape, norm)

    def vjp(g):
        scaled = anp.conj(g / fac)
        restored = irfft_fun(scaled, *args, **kwargs)
        return match_complex(x, truncate_pad(restored, in_shape))

    return vjp
|
|
|
def irfft_grad(get_args, rfft_fun, ans, x, *args, **kwargs):
    """Build the VJP for an inverse real FFT (irfft, irfft2, irfftn).

    The incoming (real) gradient is sent through the matching forward
    transform ``rfft_fun`` and truncated/zero-padded to the shape of ``x``.
    It is then conjugated and multiplied by the compression/normalization
    factors from ``make_rfft_factors``.

    Args:
      get_args: parser returning ``(axes, s, norm)`` from the call arguments.
      rfft_fun: forward real transform used to apply the adjoint.
      ans: the forward (real) output; its shape gives the default full
        transform lengths.
      x: the forward input (compressed half spectrum).
      *args, **kwargs: the remaining arguments of the forward call.

    Raises:
      NotImplementedError: for repeated axes or an odd-length last axis.
    """
    axes, gs, norm = get_args(x, *args, **kwargs)
    vs = vspace(x)
    gvs = vspace(ans)
    check_no_repeated_axes(axes)
    if gs is None:
        gs = [gvs.shape[i] for i in axes]
    check_even_shape(gs)

    # gs is the full (uncompressed) transform shape;
    # s is the compressed shape (n // 2 + 1 bins along the last axis).
    s = list(gs)
    s[-1] = s[-1] // 2 + 1
    # The factors depend only on shapes and `norm`, so build them once here
    # (as rfft_grad does) rather than on every VJP call.
    fac = make_rfft_factors(axes, vs.shape, s, gs, norm)

    def vjp(g):
        r = match_complex(x, truncate_pad(rfft_fun(g, *args, **kwargs), vs.shape))
        return anp.conj(r) * fac

    return vjp
|
|
|
# Real-input FFTs and their inverses: each forward transform's VJP is applied
# through its inverse and vice versa. The argument parser matches the NumPy
# signature of the transform family (1-D, 2-D, N-D). Lambdas are used so the
# parsers, which are defined further down in this module, are looked up at
# call time.
defvjp(rfft, lambda *args, **kwargs: rfft_grad(get_fft_args, irfft, *args, **kwargs))
defvjp(irfft, lambda *args, **kwargs: irfft_grad(get_fft_args, rfft, *args, **kwargs))

defvjp(rfft2, lambda *args, **kwargs: rfft_grad(get_fft2_args, irfft2, *args, **kwargs))
defvjp(irfft2, lambda *args, **kwargs: irfft_grad(get_fft2_args, rfft2, *args, **kwargs))

defvjp(rfftn, lambda *args, **kwargs: rfft_grad(get_fftn_args, irfftn, *args, **kwargs))
defvjp(irfftn, lambda *args, **kwargs: irfft_grad(get_fftn_args, rfftn, *args, **kwargs))
|
|
|
# fftshift/ifftshift only permute elements, so each one's VJP applies the
# opposite shift (over the same axes) to the conjugated incoming gradient.
defvjp(fftshift,
       lambda ans, x, axes=None:
           lambda g: match_complex(x, anp.conj(ifftshift(anp.conj(g), axes))))
defvjp(ifftshift,
       lambda ans, x, axes=None:
           lambda g: match_complex(x, anp.conj(fftshift(anp.conj(g), axes))))
|
|
|
@primitive
def truncate_pad(x, shape):
    """Return ``x`` zero-padded and/or truncated to exactly ``shape``.

    Each axis is padded with zeros at its end when ``x`` is shorter than the
    target length, and then cut down to the first ``shape[i]`` entries.
    """
    growth = anp.maximum(0, anp.array(shape) - anp.array(x.shape))
    pad_width = tuple(zip(anp.zeros(len(shape), dtype=int), growth))
    crop = tuple(slice(n) for n in shape)
    return anp.pad(x, pad_width, 'constant')[crop]


# Padding/truncation is its own adjoint: map the gradient back to x's shape.
defvjp(truncate_pad,
       lambda ans, x, shape:
           lambda g: match_complex(x, truncate_pad(g, vspace(x).shape)))
|
|
|
|
|
def check_no_repeated_axes(axes):
    """Raise NotImplementedError if ``axes`` contains a duplicate entry.

    Entries are compared literally, so two different values naming the same
    axis (e.g. 0 and -ndim) are not detected here.
    """
    if len(set(axes)) < len(axes):
        raise NotImplementedError("FFT gradient for repeated axes not implemented.")
|
|
|
def check_even_shape(shape):
    """Raise NotImplementedError unless the last entry of ``shape`` is even."""
    last_length = shape[-1]
    if last_length % 2:
        raise NotImplementedError("Real FFT gradient for odd lengthed last axes is not implemented.")
|
|
|
def get_fft_args(a, n=None, axis=-1, norm=None, *args, **kwargs):
    """Extract ``(axes, s, norm)`` from the arguments of a 1-D FFT call.

    The parameters mirror NumPy's ``fft(a, n=None, axis=-1, norm=None)``.
    A length given by keyword (e.g. ``rfft(x, n=8)``) is therefore picked up
    just like a positional one. Previously the parameter was named ``d``, so
    ``n=`` fell into ``**kwargs`` and was silently ignored, and the real-FFT
    gradients used the input length instead of ``n``.

    Returns:
      ``axes`` as a one-element list, the transform length wrapped in a
      one-element list (or None when not given), and ``norm``.
    """
    # `d` was the former name of the length parameter; honour it when a
    # direct caller still passes it by keyword.
    if n is None:
        n = kwargs.get('d')
    axes = [axis]
    s = None if n is None else [n]
    return axes, s, norm
|
|
|
def get_fft2_args(a, s=None, axes=(-2, -1), norm=None, *args, **kwargs):
    """Extract ``(axes, s, norm)`` from the arguments of a 2-D FFT call.

    The parameters mirror ``fft2(a, s=None, axes=(-2, -1), norm=None)``;
    ``axes`` and ``s`` are passed through unchanged.
    """
    parsed = (axes, s, norm)
    return parsed
|
|
|
def get_fftn_args(a, s=None, axes=None, norm=None, *args, **kwargs):
    """Extract ``(axes, s, norm)`` from the arguments of an N-D FFT call.

    The parameters mirror ``fftn(a, s=None, axes=None, norm=None)``. When
    ``axes`` is omitted, every axis of ``a`` (``0 .. a.ndim - 1``) is used.
    """
    chosen_axes = axes if axes is not None else list(range(a.ndim))
    return chosen_axes, s, norm
|
|
|
def make_rfft_factors(axes, resshape, facshape, normshape, norm):
    """Make the compression factors and normalization for rfft/irfft VJPs.

    A real FFT stores only part of the spectrum along the last transformed
    axis. Every stored bin except the first (index 0) and the last stored one
    (index ``facshape[-1] - 1``) stands for two conjugate-symmetric entries
    of the full spectrum, so it gets weight 2; the others get weight 1.

    Args:
      axes: transformed axes; only ``axes[-1]`` (the compressed axis) is used.
      resshape: shape of the returned factor array.
      facshape: compressed transform shape; ``facshape[-1]`` locates the last
        stored bin. If that bin lies outside ``resshape`` along the compressed
        axis, only index 0 gets weight 1.
      normshape: full (uncompressed) transform shape; its product N is the
        normalization constant.
      norm: the ``norm`` argument of the transform: None or "backward"
        (unscaled forward transform), "ortho", or "forward".

    Returns:
      A float array of shape ``resshape`` holding the weights, scaled by 1/N
      for None/"backward", by N for "forward", and unscaled for "ortho".
    """
    N = 1.0
    for n in normshape:
        N *= n

    # In-place writes are fine: this array is a constant, not a traced value.
    fac = anp.zeros(resshape)
    fac[...] = 2
    last_axis = axes[-1]
    index = [slice(None)] * len(resshape)
    if facshape[-1] <= resshape[last_axis]:
        index[last_axis] = (0, facshape[-1] - 1)
    else:
        index[last_axis] = (0,)
    fac[tuple(index)] = 1

    # Match NumPy's normalization modes. "backward" is the explicit spelling
    # of the default (None) and must scale identically; "forward" moves the
    # 1/N onto the forward transform, so the factor is inverted; "ortho"
    # splits it symmetrically and needs no extra scaling.
    if norm is None or norm == "backward":
        fac /= N
    elif norm == "forward":
        fac *= N
    return fac
|
|
|
|