File size: 3,632 Bytes
ab4488b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
from autograd.extend import VSpace
from autograd.builtins import NamedTupleVSpace

class ArrayVSpace(VSpace):
    """Vector space of non-complex NumPy arrays (and scalars) of a fixed shape/dtype.

    The space is described entirely by the ``shape`` and ``dtype`` of the
    value it was constructed from.
    """

    def __init__(self, value):
        # asarray lets Python/NumPy scalars be described with a 0-d shape.
        value = np.asarray(value)
        self.shape = value.shape
        self.dtype = value.dtype

    @property
    def size(self):
        """Number of scalar coordinates in the space (product of ``shape``).

        Cast to ``int`` because ``np.prod(())`` is the float ``1.0`` for 0-d
        (scalar) shapes, and a coordinate count should be an integer.
        """
        return int(np.prod(self.shape))

    @property
    def ndim(self):
        """Number of array axes."""
        return len(self.shape)

    def zeros(self):
        """Return an all-zeros array with this space's shape and dtype."""
        return np.zeros(self.shape, dtype=self.dtype)

    def ones(self):
        """Return an all-ones array with this space's shape and dtype."""
        return np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield each standard basis vector: a zero array holding a single 1.

        Vectors are produced in ``np.ndindex`` order (last axis fastest).
        """
        for idxs in np.ndindex(*self.shape):
            vect = np.zeros(self.shape, dtype=self.dtype)
            vect[idxs] = 1
            yield vect

    def randn(self):
        """Return standard-normal samples of this shape, cast to this dtype."""
        return np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        """Return the Euclidean inner product of ``x`` and ``y`` over all elements."""
        return np.dot(np.ravel(x), np.ravel(y))

class ComplexArrayVSpace(ArrayVSpace):
    """Vector space of complex NumPy arrays, treated as a real vector space.

    Each complex element contributes two real coordinates (its real and
    imaginary parts). That is why ``size`` is doubled and why
    ``standard_basis`` yields both a ``1`` and a ``1j`` vector per element.
    """
    iscomplex = True

    @property
    def size(self):
        """Number of real coordinates: two per complex element.

        Cast to ``int`` because ``np.prod(())`` is the float ``1.0`` for 0-d
        (scalar) shapes.
        """
        return int(np.prod(self.shape)) * 2

    def ones(self):
        """Return an array whose every element is ``1 + 1j``."""
        return (         np.ones(self.shape, dtype=self.dtype)
                + 1.0j * np.ones(self.shape, dtype=self.dtype))

    def standard_basis(self):
        """Yield, for each element in ``np.ndindex`` order, a ``1`` vector then a ``1j`` vector."""
        for idxs in np.ndindex(*self.shape):
            for v in [1.0, 1.0j]:
                vect = np.zeros(self.shape, dtype=self.dtype)
                vect[idxs] = v
                yield vect

    def randn(self):
        """Return samples whose real and imaginary parts are drawn by separate ``randn`` calls."""
        return (         np.array(np.random.randn(*self.shape)).astype(self.dtype)
                + 1.0j * np.array(np.random.randn(*self.shape)).astype(self.dtype))

    def _inner_prod(self, x, y):
        """Return ``Re(sum(conj(x) * y))``, the inner product of the underlying real space."""
        return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))

    def _covector(self, x):
        """Return the complex conjugate of ``x``."""
        return np.conj(x)

# ndarrays choose their vspace per value: complex-typed arrays get the
# real-space-of-complex treatment, everything else the plain array space.
VSpace.register(
    np.ndarray,
    lambda x: (ComplexArrayVSpace if np.iscomplexobj(x) else ArrayVSpace)(x),
)

# Associate built-in and NumPy scalar types with the matching array vspace
# class: real floating types first, then complex types.
for vspace_cls, scalar_types in (
    (ArrayVSpace, (float, np.longdouble, np.float64, np.float32, np.float16)),
    (ComplexArrayVSpace, (complex, np.clongdouble, np.complex64, np.complex128)),
):
    for type_ in scalar_types:
        vspace_cls.register(type_)


# On NumPy versions that define namedtuple result types for np.linalg
# (EigResult, EighResult, QRResult, SlogdetResult, SVDResult), register a
# NamedTupleVSpace for each of them.
#
# Both checks use NumpyVersion: a raw string comparison such as
# np.__version__ >= '1.25' is lexicographic and misorders versions
# (e.g. '1.3.0' >= '1.25' is True). NumpyVersion needs a full 'X.Y.Z' string.
_numpy_version = np.lib.NumpyVersion(np.__version__)
if _numpy_version >= '2.0.0':
    # NumPy 2 exposes these types from the private ``_linalg`` module.
    _linalg_impl = np.linalg._linalg
elif _numpy_version >= '1.25.0':
    _linalg_impl = np.linalg.linalg
else:
    _linalg_impl = None  # older NumPy: no namedtuple results to register

if _linalg_impl is not None:
    class EigResultVSpace(NamedTupleVSpace):     seq_type = _linalg_impl.EigResult
    class EighResultVSpace(NamedTupleVSpace):    seq_type = _linalg_impl.EighResult
    class QRResultVSpace(NamedTupleVSpace):      seq_type = _linalg_impl.QRResult
    class SlogdetResultVSpace(NamedTupleVSpace): seq_type = _linalg_impl.SlogdetResult
    class SVDResultVSpace(NamedTupleVSpace):     seq_type = _linalg_impl.SVDResult

    EigResultVSpace.register(_linalg_impl.EigResult)
    EighResultVSpace.register(_linalg_impl.EighResult)
    QRResultVSpace.register(_linalg_impl.QRResult)
    SlogdetResultVSpace.register(_linalg_impl.SlogdetResult)
    SVDResultVSpace.register(_linalg_impl.SVDResult)