from torch.utils.ffi import _wrap_function | |
from ._knn_pytorch import lib as _lib, ffi as _ffi | |
__all__ = [] | |
def _import_symbols(locals): | |
for symbol in dir(_lib): | |
fn = getattr(_lib, symbol) | |
if callable(fn): | |
locals[symbol] = _wrap_function(fn, _ffi) | |
else: | |
locals[symbol] = fn | |
__all__.append(symbol) | |
_import_symbols(locals()) | |