import numpy as np
from autograd.extend import VSpace
from autograd.builtins import NamedTupleVSpace


class ArrayVSpace(VSpace):
    """Vector space of real numpy arrays sharing one shape and dtype."""

    def __init__(self, value):
        value = np.asarray(value)
        self.shape = value.shape
        self.dtype = value.dtype

    @property
    def size(self):
        # np.prod(()) is the float 1.0 for 0-d arrays; int() keeps the
        # dimension an integer for every shape.
        return int(np.prod(self.shape))

    @property
    def ndim(self):
        return len(self.shape)

    def zeros(self):
        return np.zeros(self.shape, dtype=self.dtype)

    def ones(self):
        return np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield, for each element index, an array with 1 there and 0 elsewhere."""
        for idxs in np.ndindex(*self.shape):
            vect = np.zeros(self.shape, dtype=self.dtype)
            vect[idxs] = 1
            yield vect

    def randn(self):
        """Return a standard-normal sample cast to this space's dtype."""
        return np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        return np.dot(np.ravel(x), np.ravel(y))


class ComplexArrayVSpace(ArrayVSpace):
    """Complex numpy arrays, viewed as a real space with two dims per element."""

    iscomplex = True

    @property
    def size(self):
        # Real and imaginary parts each count as one real dimension.
        return int(np.prod(self.shape)) * 2

    def ones(self):
        return np.ones(self.shape, dtype=self.dtype) + 1.0j * np.ones(
            self.shape, dtype=self.dtype
        )

    def standard_basis(self):
        """Yield a real (1.0) and an imaginary (1.0j) unit vector per element."""
        for idxs in np.ndindex(*self.shape):
            for v in [1.0, 1.0j]:
                vect = np.zeros(self.shape, dtype=self.dtype)
                vect[idxs] = v
                yield vect

    def randn(self):
        """Return independent standard-normal real and imaginary parts."""
        return np.array(np.random.randn(*self.shape)).astype(
            self.dtype
        ) + 1.0j * np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        # Real part of the Hermitian inner product <x, y> = sum(conj(x) * y).
        return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))

    def _covector(self, x):
        return np.conj(x)


# ndarrays pick the real or complex space based on their dtype.
VSpace.register(
    np.ndarray,
    lambda x: ComplexArrayVSpace(x) if np.iscomplexobj(x) else ArrayVSpace(x),
)

for type_ in [float, np.longdouble, np.float64, np.float32, np.float16]:
    ArrayVSpace.register(type_)

for type_ in [complex, np.clongdouble, np.complex64, np.complex128]:
    ComplexArrayVSpace.register(type_)


# Namedtuple result types of numpy.linalg are defined in
# numpy.linalg._linalg from numpy 2.0 and in numpy.linalg.linalg for
# 1.25 <= numpy < 2.0; older numpy has none of them.  NumpyVersion is used
# for both checks because plain string comparison mis-orders versions
# (e.g. "1.9.0" >= "1.25" is True), and it requires full "X.Y.Z" strings.
_NUMPY_VERSION = np.lib.NumpyVersion(np.__version__)
if _NUMPY_VERSION >= "2.0.0":
    _linalg_impl = np.linalg._linalg
elif _NUMPY_VERSION >= "1.25.0":
    _linalg_impl = np.linalg.linalg
else:
    _linalg_impl = None

if _linalg_impl is not None:

    class EigResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_impl.SVDResult

    EigResultVSpace.register(_linalg_impl.EigResult)
    EighResultVSpace.register(_linalg_impl.EighResult)
    QRResultVSpace.register(_linalg_impl.QRResult)
    SlogdetResultVSpace.register(_linalg_impl.SlogdetResult)
    SVDResultVSpace.register(_linalg_impl.SVDResult)