|
import argparse
|
|
import os
|
|
import random
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.parallel
|
|
import torch.backends.cudnn as cudnn
|
|
import torch.optim as optim
|
|
import torch.utils.data
|
|
import torchvision.transforms as transforms
|
|
import torchvision.utils as vutils
|
|
from torch.autograd import Variable
|
|
from PIL import Image
|
|
import numpy as np
|
|
import pdb
|
|
import torch.nn.functional as F
|
|
from lib.pspnet import PSPNet
|
|
|
|
# Registry of zero-argument factories, keyed by backbone name. Each factory
# builds a PSPNet with pyramid pooling sizes (1, 2, 3, 6). The ResNet-18/34
# variants use 512 PSP channels and 256 deep features; the deeper
# ResNet-50/101/152 variants use 2048 and 1024 respectively.
psp_models = dict(
    resnet18=lambda: PSPNet(backend='resnet18', sizes=(1, 2, 3, 6),
                            psp_size=512, deep_features_size=256),
    resnet34=lambda: PSPNet(backend='resnet34', sizes=(1, 2, 3, 6),
                            psp_size=512, deep_features_size=256),
    resnet50=lambda: PSPNet(backend='resnet50', sizes=(1, 2, 3, 6),
                            psp_size=2048, deep_features_size=1024),
    resnet101=lambda: PSPNet(backend='resnet101', sizes=(1, 2, 3, 6),
                             psp_size=2048, deep_features_size=1024),
    resnet152=lambda: PSPNet(backend='resnet152', sizes=(1, 2, 3, 6),
                             psp_size=2048, deep_features_size=1024),
)
|
|
|
|
class ModifiedResnet(nn.Module):
    """PSPNet image encoder wrapped in ``nn.DataParallel``.

    The wrapped network is built from the module-level ``psp_models`` registry.

    Args:
        usegpu: accepted for backward compatibility; it is not used.
        backend: key into ``psp_models`` (case-insensitive). Defaults to
            ``'resnet18'``, which is the only backend the original code built.
            NOTE(review): the pose heads in this file expect a 32-channel
            embedding; confirm other backends produce the same width.

    Raises:
        ValueError: if ``backend`` is not a key of ``psp_models``.
    """

    def __init__(self, usegpu=True, backend='resnet18'):
        super(ModifiedResnet, self).__init__()
        key = backend.lower()
        if key not in psp_models:
            raise ValueError('Unknown PSPNet backend %r; expected one of %s'
                             % (backend, sorted(psp_models)))
        # Kept inside DataParallel so state_dict keys stay "model.module.*",
        # matching checkpoints produced by the original implementation.
        self.model = nn.DataParallel(psp_models[key]())

    def forward(self, x):
        """Run the wrapped PSPNet on ``x`` and return its output unchanged."""
        return self.model(x)
|
|
|
|
class PoseNetFeat(nn.Module):
    """Fuse per-point geometry with per-point embeddings into a dense feature.

    Two parallel 1x1-conv branches process the point coordinates (3 input
    channels) and the embedding (32 input channels). Their activations are
    concatenated after each of two stages. The stage-2 fusion then passes
    through a global head (conv5 -> conv6 -> average pool), and the pooled
    vector is tiled back to every point.

    Output channels: (64 + 64) + (128 + 128) + 1024 = 1408.

    Args:
        num_points: length of the point axis; the average pool spans exactly
            this many points, so inputs are expected to have this length.
    """

    def __init__(self, num_points):
        super(PoseNetFeat, self).__init__()
        # Geometry branch.
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)

        # Embedding branch.
        self.e_conv1 = torch.nn.Conv1d(32, 64, 1)
        self.e_conv2 = torch.nn.Conv1d(64, 128, 1)

        # Global head over the stage-2 fusion (128 + 128 = 256 channels).
        self.conv5 = torch.nn.Conv1d(256, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 1024, 1)

        self.ap1 = torch.nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Return the (bs, 1408, num_points) fused feature.

        Args:
            x: point tensor with 3 channels, laid out (bs, 3, num_points).
            emb: embedding tensor with 32 channels, (bs, 32, num_points).
        """
        geo1 = F.relu(self.conv1(x))
        emb1 = F.relu(self.e_conv1(emb))
        fused_stage1 = torch.cat((geo1, emb1), dim=1)

        geo2 = F.relu(self.conv2(geo1))
        emb2 = F.relu(self.e_conv2(emb1))
        fused_stage2 = torch.cat((geo2, emb2), dim=1)

        head = F.relu(self.conv6(F.relu(self.conv5(fused_stage2))))
        # Pool over all points, then broadcast the global vector to each point.
        tiled_global = self.ap1(head).view(-1, 1024, 1).repeat(1, 1, self.num_points)

        return torch.cat([fused_stage1, fused_stage2, tiled_global], 1)
|
|
|
|
class PoseNet(nn.Module):
    """Dense per-point pose head on top of a CNN image embedding.

    For each of ``num_points`` points, the network predicts three outputs,
    each once per object class:

    - a 4-value rotation output (``conv4_r``);
    - a 3-value translation output (``conv4_t``);
    - a sigmoid confidence (``conv4_c``).

    The outputs are then reduced to the class index (or indices) requested in
    ``obj`` for every sample in the batch.

    Args:
        num_points: number of sampled points per sample.
        num_obj: number of object classes the heads predict for.
    """

    def __init__(self, num_points, num_obj):
        super(PoseNet, self).__init__()
        self.num_points = num_points
        self.cnn = ModifiedResnet()
        self.feat = PoseNetFeat(num_points)

        # 1408 = 128 + 256 + 1024, the channel width produced by PoseNetFeat.
        self.conv1_r = torch.nn.Conv1d(1408, 640, 1)
        self.conv1_t = torch.nn.Conv1d(1408, 640, 1)
        self.conv1_c = torch.nn.Conv1d(1408, 640, 1)

        self.conv2_r = torch.nn.Conv1d(640, 256, 1)
        self.conv2_t = torch.nn.Conv1d(640, 256, 1)
        self.conv2_c = torch.nn.Conv1d(640, 256, 1)

        self.conv3_r = torch.nn.Conv1d(256, 128, 1)
        self.conv3_t = torch.nn.Conv1d(256, 128, 1)
        self.conv3_c = torch.nn.Conv1d(256, 128, 1)

        self.conv4_r = torch.nn.Conv1d(128, num_obj*4, 1)
        self.conv4_t = torch.nn.Conv1d(128, num_obj*3, 1)
        self.conv4_c = torch.nn.Conv1d(128, num_obj*1, 1)

        self.num_obj = num_obj

    @staticmethod
    def _select_obj(pred, obj):
        """Select rows ``obj[i]`` of ``pred[i]`` for each sample i and concatenate along dim 0."""
        return torch.cat([torch.index_select(pred[i], 0, obj[i])
                          for i in range(pred.size(0))], 0)

    def forward(self, img, x, choose, obj):
        """Predict per-point rotation, translation and confidence.

        Args:
            img: image batch for the CNN; the CNN output must be 4-D
                (bs, di, H, W).
            x: points laid out (bs, num_points, 3); transposed to channels-first.
            choose: flat indices into the H*W embedding grid, expected shape
                (bs, 1, num_points).
            obj: LongTensor of class indices; ``obj[i]`` must be 1-D for every
                sample i < bs.

        Returns:
            (out_rx, out_tx, out_cx, emb). The first three are shaped
            (bs * k, num_points, C) with C = 4, 3 and 1 respectively, where
            k = len(obj[i]). For bs == 1 this matches the original
            single-sample behaviour exactly. ``emb`` is the gathered
            (bs, di, num_points) embedding, detached from the graph.
        """
        out_img = self.cnn(img)

        bs, di, _, _ = out_img.size()

        # Pick the embedding column at each chosen pixel.
        emb = out_img.view(bs, di, -1)
        choose = choose.repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()

        x = x.transpose(2, 1).contiguous()
        ap_x = self.feat(x, emb)

        rx = F.relu(self.conv1_r(ap_x))
        tx = F.relu(self.conv1_t(ap_x))
        cx = F.relu(self.conv1_c(ap_x))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))
        cx = F.relu(self.conv2_c(cx))

        rx = F.relu(self.conv3_r(rx))
        tx = F.relu(self.conv3_t(tx))
        cx = F.relu(self.conv3_c(cx))

        rx = self.conv4_r(rx).view(bs, self.num_obj, 4, self.num_points)
        tx = self.conv4_t(tx).view(bs, self.num_obj, 3, self.num_points)
        cx = torch.sigmoid(self.conv4_c(cx)).view(bs, self.num_obj, 1, self.num_points)

        # Previously only sample 0 was selected (b = 0), silently dropping the
        # rest of the batch; now every sample's requested class is kept.
        out_rx = self._select_obj(rx, obj)
        out_tx = self._select_obj(tx, obj)
        out_cx = self._select_obj(cx, obj)

        # (n, C, num_points) -> (n, num_points, C)
        out_rx = out_rx.transpose(2, 1).contiguous()
        out_cx = out_cx.transpose(2, 1).contiguous()
        out_tx = out_tx.transpose(2, 1).contiguous()

        return out_rx, out_tx, out_cx, emb.detach()
|
|
|
|
|
|
|
|
class PoseRefineNetFeat(nn.Module):
    """Pool fused point/embedding features into one 1024-d vector per sample.

    Points (3 channels) and embeddings (32 channels) are processed by two
    parallel 1x1-conv branches. Both fusion stages, (64 + 64) + (128 + 128)
    = 384 channels, feed conv5 -> conv6, and the result is averaged over the
    point axis.

    Args:
        num_points: length of the point axis; the average pool spans exactly
            this many points.
    """

    def __init__(self, num_points):
        super(PoseRefineNetFeat, self).__init__()
        # Geometry branch.
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)

        # Embedding branch.
        self.e_conv1 = torch.nn.Conv1d(32, 64, 1)
        self.e_conv2 = torch.nn.Conv1d(64, 128, 1)

        # Head over both fusion stages (128 + 256 = 384 channels).
        self.conv5 = torch.nn.Conv1d(384, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 1024, 1)

        self.ap1 = torch.nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Return a (bs, 1024) pooled feature.

        Args:
            x: point tensor (bs, 3, num_points).
            emb: embedding tensor (bs, 32, num_points).
        """
        geo = F.relu(self.conv1(x))
        emb_feat = F.relu(self.e_conv1(emb))
        early = torch.cat([geo, emb_feat], dim=1)

        geo = F.relu(self.conv2(geo))
        emb_feat = F.relu(self.e_conv2(emb_feat))
        late = torch.cat([geo, emb_feat], dim=1)

        both_stages = torch.cat([early, late], dim=1)
        head = F.relu(self.conv6(F.relu(self.conv5(both_stages))))

        return self.ap1(head).view(-1, 1024)
|
|
|
|
class PoseRefineNet(nn.Module):
    """Regress a pose update from pooled point/embedding features.

    A PoseRefineNetFeat vector (1024-d) goes through two MLP branches that
    output, per object class, a 4-value rotation and a 3-value translation.
    The class index (or indices) given in ``obj`` is then selected for every
    sample. Layer names keep the historical ``conv*`` prefix (they are Linear
    layers) so state_dict keys remain checkpoint-compatible.

    Args:
        num_points: number of points per sample.
        num_obj: number of object classes the heads predict for.
    """

    def __init__(self, num_points, num_obj):
        super(PoseRefineNet, self).__init__()
        self.num_points = num_points
        self.feat = PoseRefineNetFeat(num_points)

        self.conv1_r = torch.nn.Linear(1024, 512)
        self.conv1_t = torch.nn.Linear(1024, 512)

        self.conv2_r = torch.nn.Linear(512, 128)
        self.conv2_t = torch.nn.Linear(512, 128)

        self.conv3_r = torch.nn.Linear(128, num_obj*4)
        self.conv3_t = torch.nn.Linear(128, num_obj*3)

        self.num_obj = num_obj

    @staticmethod
    def _select_obj(pred, obj):
        """Select rows ``obj[i]`` of ``pred[i]`` for each sample i and concatenate along dim 0."""
        return torch.cat([torch.index_select(pred[i], 0, obj[i])
                          for i in range(pred.size(0))], 0)

    def forward(self, x, emb, obj):
        """Predict the rotation/translation update.

        Args:
            x: points laid out (bs, num_points, 3); transposed to channels-first.
            emb: embedding tensor (bs, 32, num_points).
            obj: LongTensor of class indices; ``obj[i]`` must be 1-D for every
                sample i < bs.

        Returns:
            (out_rx, out_tx) shaped (bs * k, 4) and (bs * k, 3), where
            k = len(obj[i]). For bs == 1 this matches the original
            single-sample behaviour exactly.
        """
        bs = x.size()[0]

        x = x.transpose(2, 1).contiguous()
        ap_x = self.feat(x, emb)

        rx = F.relu(self.conv1_r(ap_x))
        tx = F.relu(self.conv1_t(ap_x))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))

        rx = self.conv3_r(rx).view(bs, self.num_obj, 4)
        tx = self.conv3_t(tx).view(bs, self.num_obj, 3)

        # Previously only sample 0 was selected (b = 0), silently dropping the
        # rest of the batch; now every sample's requested class is kept.
        out_rx = self._select_obj(rx, obj)
        out_tx = self._select_obj(tx, obj)

        return out_rx, out_tx
|
|
|