File size: 3,632 Bytes
ab4488b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import numpy as np
from autograd.extend import VSpace
from autograd.builtins import NamedTupleVSpace
class ArrayVSpace(VSpace):
    """Vector space of real-valued NumPy arrays with a fixed shape and dtype.

    Parameters
    ----------
    value : array_like
        Example value; only its shape and dtype are recorded.
    """

    def __init__(self, value):
        # asarray lets Python/NumPy scalars and nested sequences be described too.
        value = np.asarray(value)
        self.shape = value.shape
        self.dtype = value.dtype

    @property
    def size(self):
        """Number of scalar coordinates in the space, as a Python ``int``."""
        # np.prod(()) is the float 1.0 for 0-d shapes and a NumPy integer
        # otherwise; normalize both cases to a plain int.
        return int(np.prod(self.shape))

    @property
    def ndim(self):
        """Number of array dimensions."""
        return len(self.shape)

    def zeros(self):
        """Return an all-zeros array of this space's shape and dtype."""
        return np.zeros(self.shape, dtype=self.dtype)

    def ones(self):
        """Return an all-ones array of this space's shape and dtype."""
        return np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield one one-hot array per element index, in row-major order."""
        for idxs in np.ndindex(*self.shape):
            vect = np.zeros(self.shape, dtype=self.dtype)
            vect[idxs] = 1
            yield vect

    def randn(self):
        """Return standard-normal samples cast to this space's dtype."""
        # randn() with no arguments returns a Python float; wrapping it in
        # np.array makes .astype work for 0-d shapes as well.
        return np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        """Dot product of ``x`` and ``y`` taken over their flattened elements."""
        return np.dot(np.ravel(x), np.ravel(y))
class ComplexArrayVSpace(ArrayVSpace):
    """Vector space of complex NumPy arrays, treated as a real space.

    Every complex element contributes two real coordinates (its real and
    imaginary parts), which is reflected in ``size`` and ``standard_basis``.
    """

    iscomplex = True

    @property
    def size(self):
        """Number of real coordinates (two per complex element), as an ``int``."""
        # np.prod(()) is the float 1.0 for 0-d shapes; cast before doubling.
        return int(np.prod(self.shape)) * 2

    def ones(self):
        """Return an array filled with ``1 + 1j`` of this space's shape and dtype."""
        return (np.ones(self.shape, dtype=self.dtype)
                + 1.0j * np.ones(self.shape, dtype=self.dtype))

    def standard_basis(self):
        """Yield, for each element index, a real unit array and then an imaginary one."""
        for idxs in np.ndindex(*self.shape):
            for v in [1.0, 1.0j]:
                vect = np.zeros(self.shape, dtype=self.dtype)
                vect[idxs] = v
                yield vect

    def randn(self):
        """Return samples with independent standard-normal real and imaginary parts."""
        return (np.array(np.random.randn(*self.shape)).astype(self.dtype)
                + 1.0j * np.array(np.random.randn(*self.shape)).astype(self.dtype))

    def _inner_prod(self, x, y):
        """Real part of the conjugated dot product ``sum(conj(x) * y)``."""
        return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))

    def _covector(self, x):
        """Return the complex conjugate of ``x``."""
        return np.conj(x)
# ndarrays choose their vector space by dtype: complex arrays get the
# conjugating inner product of ComplexArrayVSpace, all others ArrayVSpace.
VSpace.register(
    np.ndarray,
    lambda arr: (ComplexArrayVSpace if np.iscomplexobj(arr) else ArrayVSpace)(arr),
)
# Map Python and NumPy scalar types onto the array vector spaces:
# real floating types first, then complex types (same order as before).
for vspace_cls, scalar_types in (
    (ArrayVSpace, (float, np.longdouble, np.float64, np.float32, np.float16)),
    (ComplexArrayVSpace, (complex, np.clongdouble, np.complex64, np.complex128)),
):
    for type_ in scalar_types:
        vspace_cls.register(type_)
# Register vector spaces for the named-tuple result types that several
# np.linalg functions return (EigResult, EighResult, QRResult,
# SlogdetResult, SVDResult). These types live in np.linalg._linalg on
# NumPy >= 2.0 and in np.linalg.linalg on NumPy 1.25.x; older NumPy
# versions do not provide them, so nothing is registered there.
# NumpyVersion compares numerically (a plain string comparison would rank
# "1.9.0" above "1.25") and requires full "X.Y.Z" version strings.
_numpy_version = np.lib.NumpyVersion(np.__version__)
if _numpy_version >= '2.0.0':
    _linalg_results_module = np.linalg._linalg
elif _numpy_version >= '1.25.0':
    _linalg_results_module = np.linalg.linalg
else:
    _linalg_results_module = None

if _linalg_results_module is not None:
    class EigResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.SVDResult

    EigResultVSpace.register(_linalg_results_module.EigResult)
    EighResultVSpace.register(_linalg_results_module.EighResult)
    QRResultVSpace.register(_linalg_results_module.QRResult)
    SlogdetResultVSpace.register(_linalg_results_module.SlogdetResult)
    SVDResultVSpace.register(_linalg_results_module.SVDResult)
|