File size: 1,166 Bytes
29858c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import unittest
import gc
import operator as op
import functools
import torch
from torch.autograd import Variable, Function
from lib.knn import knn_pytorch as knn_pytorch
class KNearestNeighbor(Function):
    """Find the k nearest reference points for each query point on the GPU.

    The search itself is delegated to the compiled ``knn_pytorch.knn``
    extension; this class only prepares the inputs and the output buffer.

    Args:
        k: Number of neighbours to return per query point.

    NOTE(review): this is written in the legacy autograd style (instance
    state plus a non-static ``forward``). Newer PyTorch releases are known to
    reject calling such instances -- confirm the PyTorch version in use.
    """

    def __init__(self, k):
        self.k = k

    def forward(self, ref, query):
        """Return neighbour indices of ``query`` points among ``ref`` points.

        Args:
            ref: Reference points. Converted to float32 and moved to the
                current CUDA device. The test in this file passes a tensor of
                shape (batch, D, N).
            query: Query points, converted the same way. Must be at least
                3-D because dimensions 0 and 2 size the result; the test in
                this file passes (batch, D, M).

        Returns:
            An int64 CUDA tensor of shape
            ``(query.shape[0], self.k, query.shape[2])``. Its contents are
            presumably written in place by ``knn_pytorch.knn`` -- verify
            against the extension (including whether indices are 0- or
            1-based).
        """
        ref = ref.float().cuda()
        query = query.float().cuda()
        # Allocate the output directly as int64 on the GPU. Building it as a
        # float CPU tensor and chaining .long().cuda() costs two extra
        # allocations/copies and casts uninitialised float memory.
        inds = torch.empty(
            (query.shape[0], self.k, query.shape[2]),
            dtype=torch.long,
            device=query.device,
        )
        knn_pytorch.knn(ref, query, inds)
        return inds
class TestKNearestNeighbor(unittest.TestCase):
    """Smoke test for ``KNearestNeighbor``.

    Requires a CUDA device and the compiled ``knn_pytorch`` extension, since
    ``KNearestNeighbor.forward`` moves its inputs to the GPU.
    """

    # Number of forward passes to run. The loop is bounded so the test
    # terminates; raise this locally when watching the per-round tensor dump
    # for growth across iterations.
    NUM_ROUNDS = 3

    def test_forward(self):
        """Run repeated forward passes and check the shape of the result.

        After every pass, each tensor currently tracked by the garbage
        collector is printed (element count, type, size) so that tensors
        accumulating between rounds are visible in the output.
        """
        k = 2
        batch, D, N, M = 2, 128, 100, 1000
        knn = KNearestNeighbor(k)

        for _ in range(self.NUM_ROUNDS):
            ref = torch.rand(batch, D, N)
            query = torch.rand(batch, D, M)

            inds = knn(ref, query)

            # forward() allocates (query.shape[0], k, query.shape[2]) int64
            # indices on the GPU; make sure that contract holds.
            self.assertEqual(tuple(inds.shape), (batch, k, M))
            self.assertEqual(inds.dtype, torch.long)
            self.assertTrue(inds.is_cuda)

            for obj in gc.get_objects():
                if torch.is_tensor(obj):
                    # Zero-dimensional tensors are reported with a count of 0.
                    count = obj.numel() if obj.dim() > 0 else 0
                    print(count, type(obj), obj.size())
            print(inds)
# Run the test suite only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    unittest.main()
|