|
import unittest
|
|
import gc
|
|
import operator as op
|
|
import functools
|
|
import torch
|
|
from torch.autograd import Variable, Function
|
|
from lib.knn import knn_pytorch as knn_pytorch
|
|
|
|
class KNearestNeighbor(Function):
    """Compute the k nearest neighbors of each query point.

    Thin wrapper around the compiled ``knn_pytorch.knn`` extension. Usage::

        knn = KNearestNeighbor(k)
        inds = knn(ref, query)

    NOTE: the class still subclasses ``torch.autograd.Function`` so existing
    ``isinstance`` checks keep working, but it is NOT used as an autograd
    function. In PyTorch 1.5+ calling an instance of a ``Function`` subclass
    whose ``forward`` is non-static raises ``RuntimeError`` ("Legacy autograd
    function with non-static forward method is deprecated"). ``__call__``
    below therefore dispatches straight to ``forward``. The result is an
    integer index tensor, so nothing needs a gradient.
    """

    def __init__(self, k):
        """
        Args:
            k: number of neighbors to return for each query point.
        """
        # Function.__init__ is intentionally not called: recent PyTorch
        # releases emit a deprecation warning when an autograd Function is
        # instantiated.
        self.k = k

    def __call__(self, ref, query):
        """Run :meth:`forward` directly, bypassing legacy autograd dispatch."""
        return self.forward(ref, query)

    def forward(self, ref, query):
        """Search ``ref`` for the ``k`` nearest neighbors of each query point.

        Both inputs are converted to float32 and moved to the current CUDA
        device before the extension is called.

        Args:
            ref: reference points. The unit test in this file passes a
                (batch, D, N) tensor; the exact layout is dictated by the
                extension -- TODO confirm.
            query: query points. Only ``query.shape[0]`` and
                ``query.shape[2]`` are read here (the test passes
                (batch, D, M)).

        Returns:
            A ``torch.long`` CUDA tensor of shape
            ``(query.shape[0], k, query.shape[2])``, filled in place by
            ``knn_pytorch.knn``. Whether the indices are 0- or 1-based is
            decided by the extension and is not visible here.
        """
        ref = ref.float().cuda()
        query = query.float().cuda()

        # Allocate the output directly as int64 on the query's device rather
        # than allocating float32 on the host, casting it, and copying it over.
        inds = torch.empty(
            query.shape[0], self.k, query.shape[2],
            dtype=torch.long, device=query.device,
        )

        knn_pytorch.knn(ref, query, inds)

        return inds
|
|
|
|
|
|
class TestKNearestNeighbor(unittest.TestCase):
    """Smoke tests for KNearestNeighbor (require CUDA and the compiled extension)."""

    @unittest.skipUnless(torch.cuda.is_available(), "KNearestNeighbor requires CUDA")
    def test_forward(self):
        """Repeated calls return well-formed indices without growing GPU memory.

        The test runs a bounded number of iterations so that it terminates.
        It checks for leaks by comparing ``torch.cuda.memory_allocated()``
        across iterations instead of printing every live tensor.
        """
        k = 2
        knn = KNearestNeighbor(k)
        D, N, M = 128, 100, 1000
        iterations = 5

        baseline = None
        for _ in range(iterations):
            ref = torch.rand(2, D, N)
            query = torch.rand(2, D, M)

            inds = knn(ref, query)

            self.assertEqual(tuple(inds.shape), (2, k, M))
            self.assertEqual(inds.dtype, torch.long)
            # The extension defines the index base (0- or 1-based); this
            # bound holds for either convention.
            self.assertGreaterEqual(int(inds.min()), 0)
            self.assertLessEqual(int(inds.max()), N)

            torch.cuda.synchronize()
            allocated = torch.cuda.memory_allocated()
            if baseline is None:
                # The first iteration establishes the steady-state footprint
                # (only the current `inds` should be alive on the GPU here).
                baseline = allocated
            else:
                self.assertEqual(
                    allocated, baseline, "GPU memory grew between iterations"
                )
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes in this module.
    unittest.main(module="__main__")
|
|
|