File size: 5,250 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import absolute_import
from builtins import zip
import numpy.fft as ffto
from .numpy_wrapper import wrap_namespace
from .numpy_vjps import match_complex
from . import numpy_wrapper as anp
from autograd.extend import primitive, defvjp, vspace

# Populate this module's namespace from numpy.fft via wrap_namespace; the
# defvjp registrations below rely on names such as fft, rfft and fftshift
# being defined here by this call.
wrap_namespace(vars(ffto), globals())

# TODO: make fft gradient work for a repeated axis,
# e.g. by replacing fftn with repeated calls to 1d fft along each axis
def fft_grad(get_args, fft_fun, ans, x, *args, **kwargs):
    """Build the VJP for a complex FFT-family primitive.

    The VJP reapplies ``fft_fun`` itself, with the forward call's extra
    arguments, to the cotangent (the DFT matrix is symmetric, so it is its
    own transpose). The result is then cropped/zero-padded back to the shape
    of ``x``.

    Args:
        get_args: normalizes the forward call's arguments into
            ``(axes, s, norm)``; only ``axes`` is used here, to reject
            repeated axes.
        fft_fun: the transform to apply to the cotangent.
        ans: output of the forward call (unused).
        x: the forward call's input.
        *args, **kwargs: the remaining forward-call arguments.

    Returns:
        A function mapping the output cotangent ``g`` to the input cotangent.

    Raises:
        NotImplementedError: if ``axes`` contains a repeated axis.
    """
    axes, _s, _norm = get_args(x, *args, **kwargs)
    check_no_repeated_axes(axes)
    target_shape = vspace(x).shape

    def vjp(g):
        transformed = fft_fun(g, *args, **kwargs)
        return match_complex(x, truncate_pad(transformed, target_shape))

    return vjp

# VJPs for the complex transforms.  The vjp-makers are lambdas so that the
# get_*_args helpers (defined further down this module) are looked up at call
# time.  Each transform is paired with the argument parser that matches its
# numpy signature, so fft_grad's repeated-axes check sees the real ``axes``
# (get_fft_args would wrap an fft2/fftn ``axes`` argument as one element,
# skipping the check and failing on a list passed positionally).
defvjp(fft, lambda *args, **kwargs:
        fft_grad(get_fft_args, fft, *args, **kwargs))
defvjp(ifft, lambda *args, **kwargs:
        fft_grad(get_fft_args, ifft, *args, **kwargs))

defvjp(fft2, lambda *args, **kwargs:
        fft_grad(get_fft2_args, fft2, *args, **kwargs))
defvjp(ifft2, lambda *args, **kwargs:
        fft_grad(get_fft2_args, ifft2, *args, **kwargs))

defvjp(fftn, lambda *args, **kwargs:
        fft_grad(get_fftn_args, fftn, *args, **kwargs))
defvjp(ifftn, lambda *args, **kwargs:
        fft_grad(get_fftn_args, ifftn, *args, **kwargs))

def rfft_grad(get_args, irfft_fun, ans, x, *args, **kwargs):
    """Build the VJP for a real-input FFT primitive (rfft / rfft2 / rfftn).

    The cotangent is divided by the factors from ``make_rfft_factors``,
    conjugated, pushed through the matching inverse transform ``irfft_fun``
    and cropped/padded back to the shape of ``x``.

    Args:
        get_args: normalizes the forward call's arguments into
            ``(axes, s, norm)``.
        irfft_fun: inverse real transform applied to the cotangent.
        ans: output of the forward call (its shape sizes the factors).
        x: the forward call's real input; its shape supplies the transform
            shape when none was given explicitly.
        *args, **kwargs: the remaining forward-call arguments.

    Returns:
        A function mapping the output cotangent ``g`` to the input cotangent.

    Raises:
        NotImplementedError: on repeated axes or an odd full length along the
            last transformed axis.
    """
    axes, s, norm = get_args(x, *args, **kwargs)
    in_shape = vspace(x).shape
    out_shape = vspace(ans).shape
    check_no_repeated_axes(axes)
    if s is None:
        s = [in_shape[ax] for ax in axes]
    check_even_shape(s)

    # ``s`` is the full (uncompressed) transform shape; ``half_s`` is the
    # compressed shape, whose last axis keeps n // 2 + 1 bins.
    half_s = list(s)
    half_s[-1] = half_s[-1] // 2 + 1
    factors = make_rfft_factors(axes, out_shape, half_s, s, norm)

    def vjp(g):
        scaled = anp.conj(g / factors)
        restored = irfft_fun(scaled, *args, **kwargs)
        return match_complex(x, truncate_pad(restored, in_shape))

    return vjp

def irfft_grad(get_args, rfft_fun, ans, x, *args, **kwargs):
    """Build the VJP for an inverse real FFT primitive (irfft / irfft2 / irfftn).

    The cotangent is sent through the matching forward real transform
    ``rfft_fun`` and cropped/padded to the shape of ``x``. It is then
    conjugated and multiplied by the factors from ``make_rfft_factors``
    (weight 2 on interior bins, 1 on the end bins, scaled per ``norm``).

    Args:
        get_args: normalizes the forward call's arguments into
            ``(axes, s, norm)``.
        rfft_fun: forward real transform applied to the cotangent.
        ans: output of the forward call; its shape supplies the full
            transform shape when none was given explicitly.
        x: the forward call's (half-spectrum) input.
        *args, **kwargs: the remaining forward-call arguments.

    Returns:
        A function mapping the output cotangent ``g`` to the input cotangent.

    Raises:
        NotImplementedError: on repeated axes or an odd full length along the
            last transformed axis.
    """
    axes, gs, norm = get_args(x, *args, **kwargs)
    vs = vspace(x)
    gvs = vspace(ans)
    check_no_repeated_axes(axes)
    if gs is None: gs = [gvs.shape[i] for i in axes]
    check_even_shape(gs)

    # gs is the full fft shape
    # s is the compressed shape
    s = list(gs)
    s[-1] = s[-1] // 2 + 1
    # The factors depend only on shapes and norm, so build them once here
    # instead of on every vjp call (as rfft_grad already does).
    fac = make_rfft_factors(axes, vs.shape, s, gs, norm)

    def vjp(g):
        r = match_complex(x, truncate_pad(rfft_fun(g, *args, **kwargs), vs.shape))
        return anp.conj(r) * fac
    return vjp

# VJPs for the real transforms: each forward transform's VJP is built from
# its inverse and vice versa.  The vjp-makers stay lambdas so that the
# get_*_args helpers (defined further down this module) resolve at call time.
defvjp(rfft, lambda ans, x, *args, **kwargs:
       rfft_grad(get_fft_args, irfft, ans, x, *args, **kwargs))
defvjp(irfft, lambda ans, x, *args, **kwargs:
       irfft_grad(get_fft_args, rfft, ans, x, *args, **kwargs))

defvjp(rfft2, lambda ans, x, *args, **kwargs:
       rfft_grad(get_fft2_args, irfft2, ans, x, *args, **kwargs))
defvjp(irfft2, lambda ans, x, *args, **kwargs:
       irfft_grad(get_fft2_args, rfft2, ans, x, *args, **kwargs))

defvjp(rfftn, lambda ans, x, *args, **kwargs:
       rfft_grad(get_fftn_args, irfftn, ans, x, *args, **kwargs))
defvjp(irfftn, lambda ans, x, *args, **kwargs:
       irfft_grad(get_fftn_args, rfftn, ans, x, *args, **kwargs))

# fftshift and ifftshift are inverse permutations of each other, so each
# one's VJP applies the other; the cotangent is conjugated before and after.
defvjp(fftshift, lambda ans, x, axes=None:
       lambda g: match_complex(x, anp.conj(ifftshift(anp.conj(g), axes))))
defvjp(ifftshift, lambda ans, x, axes=None:
       lambda g: match_complex(x, anp.conj(fftshift(anp.conj(g), axes))))

@primitive
def truncate_pad(x, shape):
    """Return ``x`` cropped and/or zero-padded to ``shape``.

    Axes shorter than the target are zero-padded at the end; every axis is
    then sliced to the target length, which crops axes that were longer.
    """
    shortfall = anp.maximum(0, anp.array(shape) - anp.array(x.shape))
    leading = anp.zeros(len(shape), dtype=int)
    widths = tuple(zip(leading, shortfall))
    window = tuple(slice(n) for n in shape)
    return anp.pad(x, widths, 'constant')[window]
# Cropping and end-zero-padding are transposes of each other, so the VJP maps
# the cotangent back onto the input's shape with the same operation.
defvjp(truncate_pad, lambda ans, x, shape: lambda g:
       match_complex(x, truncate_pad(g, vspace(x).shape)))

## TODO: could be made less stringent, to fail only when repeated axis has different values of s
def check_no_repeated_axes(axes):
    """Raise NotImplementedError if any axis occurs more than once in ``axes``."""
    if len(set(axes)) == len(axes):
        return
    raise NotImplementedError("FFT gradient for repeated axes not implemented.")

def check_even_shape(shape):
    """Raise NotImplementedError unless the last entry of ``shape`` is even."""
    last = shape[-1]
    if last % 2:
        raise NotImplementedError("Real FFT gradient for odd lengthed last axes is not implemented.")

def get_fft_args(a, d=None, axis=-1, norm=None, *args, **kwargs):
    """Normalize the arguments of a 1-D FFT call into ``(axes, s, norm)``.

    The positional layout follows numpy's ``fft(a, n=None, axis=-1,
    norm=None)``. The transform length is accepted positionally (as ``d``,
    kept for backward compatibility) or under numpy's keyword name ``n``.
    A keyword ``n`` does not bind to ``d``, so it is read from ``kwargs``.

    Returns:
        ``axes``: the one-element list ``[axis]``.
        ``s``: ``[length]`` if a length was given, else ``None``.
        ``norm``: passed through unchanged.
    """
    if d is None:
        # Without this, ``rfft(x, n=8)`` would report no length and the VJP
        # would infer the transform size from ``x`` instead of using 8.
        d = kwargs.get('n')
    axes = [axis]
    if d is not None:
        d = [d]
    return axes, d, norm

def get_fft2_args(a, s=None, axes=(-2, -1), norm=None, *args, **kwargs):
    """Normalize the arguments of a 2-D FFT call into ``(axes, s, norm)``.

    Parameters follow numpy's ``fft2(a, s=None, axes=(-2, -1), norm=None)``.
    The three values are returned unchanged; ``a`` and any extra arguments
    are ignored.
    """
    result = (axes, s, norm)
    return result

def get_fftn_args(a, s=None, axes=None, norm=None, *args, **kwargs):
    """Normalize the arguments of an N-D FFT call into ``(axes, s, norm)``.

    Parameters follow numpy's ``fftn(a, s=None, axes=None, norm=None)``.
    When ``axes`` is omitted, numpy transforms the last ``len(s)`` axes if
    ``s`` is given, and every axis of ``a`` otherwise. The returned ``axes``
    follows the same rule, so it always lines up entry-for-entry with ``s``.
    """
    if axes is None:
        if s is None:
            axes = list(range(a.ndim))
        else:
            axes = list(range(-len(s), 0))
    return axes, s, norm

def make_rfft_factors(axes, resshape, facshape, normshape, norm):
    """Make the compression factors and the normalization for rfft/irfft VJPs.

    Args:
        axes: transformed axes; only ``axes[-1]`` (the compressed axis) is used.
        resshape: shape of the returned factor array.
        facshape: compressed transform shape; ``facshape[-1]`` is the number
            of half-spectrum bins along ``axes[-1]``.
        normshape: full transform shape; its product ``N`` is the
            normalization constant.
        norm: numpy normalization mode. ``None`` and ``"backward"`` divide by
            ``N``, ``"ortho"`` leaves the weights unscaled, and ``"forward"``
            multiplies by ``N``.

    Returns:
        A float array of shape ``resshape``. It holds 2 everywhere except 1 at
        index 0 and, when it lies inside ``resshape``, at index
        ``facshape[-1] - 1`` along ``axes[-1]``. The array is scaled as
        described for ``norm``.
    """
    # The factors are a constant that never enters autograd's trace, so plain
    # numpy is used and in-place modification is safe.
    import numpy as onp

    N = 1.0
    for n in normshape:
        N = N * n

    # Interior half-spectrum bins stand for a bin and its conjugate partner in
    # the full spectrum (weight 2); the first and last bins appear once.
    fac = onp.zeros(resshape)
    fac[...] = 2
    index = [slice(None)] * len(resshape)
    if facshape[-1] <= resshape[axes[-1]]:
        index[axes[-1]] = [0, facshape[-1] - 1]
    else:
        # The last bin was cropped away, so only the first bin is present.
        index[axes[-1]] = [0]
    fac[tuple(index)] = 1

    # numpy puts 1/N on the inverse transform for "backward" (the default,
    # also spelled None), on the forward transform for "forward", and
    # 1/sqrt(N) on both for "ortho"; the VJP must mirror that scaling.
    if norm is None or norm == "backward":
        fac /= N
    elif norm == "forward":
        fac *= N
    return fac