"""Network definitions: PoseNet (per-point pose + confidence) and PoseRefineNet.

PoseNet runs an image through a PSPNet backbone, gathers per-point embeddings
from the resulting feature map, fuses them with point coordinates and predicts
a quaternion, a translation and a confidence for every point.  PoseRefineNet
predicts a single quaternion/translation pair from points and embeddings.
"""
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from PIL import Image
import numpy as np
import pdb
import torch.nn.functional as F
from lib.pspnet import PSPNet

# Zero-argument factories for the supported PSPNet backbones, keyed by the
# ResNet variant name.  Each call builds a fresh, independent network.
psp_models = {
    'resnet18': lambda: PSPNet(
        sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256,
        backend='resnet18'),
    'resnet34': lambda: PSPNet(
        sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256,
        backend='resnet34'),
    'resnet50': lambda: PSPNet(
        sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024,
        backend='resnet50'),
    'resnet101': lambda: PSPNet(
        sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024,
        backend='resnet101'),
    'resnet152': lambda: PSPNet(
        sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024,
        backend='resnet152'),
}


class ModifiedResnet(nn.Module):
    """ResNet-18 backed PSPNet wrapped in ``nn.DataParallel``."""

    def __init__(self, usegpu=True):
        """Build the backbone.

        Args:
            usegpu: Not read anywhere in this class; it is accepted only so
                that existing call sites passing it keep working.
        """
        super(ModifiedResnet, self).__init__()
        backbone = psp_models['resnet18']()
        self.model = nn.DataParallel(backbone)

    def forward(self, x):
        """Return the wrapped PSPNet's output for ``x``."""
        return self.model(x)


class PoseNetFeat(nn.Module):
    """Fuse per-point coordinates and embeddings into a 1408-channel feature.

    The output concatenates two levels of per-point features (128 and 256
    channels) with a 1024-channel average-pooled feature that is repeated for
    every point.
    """

    def __init__(self, num_points):
        """Create the 1x1 convolutions and the pooling layer.

        Args:
            num_points: Kernel size of the average pool and the number of
                times the pooled feature is repeated along the point axis.
        """
        super(PoseNetFeat, self).__init__()
        # Point branch: 3 -> 64 -> 128 channels.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)

        # Embedding branch: 32 -> 64 -> 128 channels.
        self.e_conv1 = nn.Conv1d(32, 64, 1)
        self.e_conv2 = nn.Conv1d(64, 128, 1)

        # Applied to the 256-channel second-level fused feature.
        self.conv5 = nn.Conv1d(256, 512, 1)
        self.conv6 = nn.Conv1d(512, 1024, 1)

        self.ap1 = nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Compute the fused per-point feature.

        Args:
            x: Point tensor with 3 channels on dim 1 (required by ``conv1``).
                Presumably (batch, 3, num_points) -- confirm against caller.
            emb: Embedding tensor with 32 channels on dim 1 (required by
                ``e_conv1``), same batch/point layout as ``x``.

        Returns:
            Concatenation along dim 1 of the 128-, 256- and 1024-channel
            features (1408 channels in total).
        """
        pts_feat = F.relu(self.conv1(x))
        emb_feat = F.relu(self.e_conv1(emb))
        pointfeat_1 = torch.cat((pts_feat, emb_feat), dim=1)

        pts_feat = F.relu(self.conv2(pts_feat))
        emb_feat = F.relu(self.e_conv2(emb_feat))
        pointfeat_2 = torch.cat((pts_feat, emb_feat), dim=1)

        fused = F.relu(self.conv5(pointfeat_2))
        fused = F.relu(self.conv6(fused))

        # Pool over the point axis, then copy the pooled vector to each point.
        pooled = self.ap1(fused)
        pooled = pooled.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat([pointfeat_1, pointfeat_2, pooled], 1)  # 128 + 256 + 1024


class PoseNet(nn.Module):
    """Predict per-point rotation, translation and confidence for an object.

    Three parallel stacks of 1x1 convolutions (suffixes ``_r``, ``_t``, ``_c``)
    map the 1408-channel fused feature to ``num_obj`` sets of quaternions,
    translations and confidences.
    """

    def __init__(self, num_points, num_obj):
        """Create the backbone, the fusion module and the three heads.

        Args:
            num_points: Number of points per sample; used to reshape the head
                outputs and passed on to ``PoseNetFeat``.
            num_obj: Number of object classes; every head emits one
                prediction set per class.
        """
        super(PoseNet, self).__init__()
        self.num_points = num_points
        self.cnn = ModifiedResnet()
        self.feat = PoseNetFeat(num_points)

        # Layers are created level by level (r, t, c) so that parameter
        # registration order stays as it always was.
        self.conv1_r = nn.Conv1d(1408, 640, 1)
        self.conv1_t = nn.Conv1d(1408, 640, 1)
        self.conv1_c = nn.Conv1d(1408, 640, 1)

        self.conv2_r = nn.Conv1d(640, 256, 1)
        self.conv2_t = nn.Conv1d(640, 256, 1)
        self.conv2_c = nn.Conv1d(640, 256, 1)

        self.conv3_r = nn.Conv1d(256, 128, 1)
        self.conv3_t = nn.Conv1d(256, 128, 1)
        self.conv3_c = nn.Conv1d(256, 128, 1)

        self.conv4_r = nn.Conv1d(128, num_obj * 4, 1)  # quaternion
        self.conv4_t = nn.Conv1d(128, num_obj * 3, 1)  # translation
        self.conv4_c = nn.Conv1d(128, num_obj * 1, 1)  # confidence

        self.num_obj = num_obj

    def forward(self, img, x, choose, obj):
        """Run the full network and select the outputs for ``obj``.

        Only batch element 0 of the head outputs and of ``obj`` is used when
        selecting the returned predictions.

        Args:
            img: Image batch fed to the PSPNet backbone.
            x: Point tensor; dims 1 and 2 are swapped before fusion, so it is
                presumably (batch, num_points, 3) -- confirm against caller.
            choose: Index tensor repeated along dim 1 once per backbone
                channel and used to gather from the flattened feature map.
                Presumably (batch, 1, num_points) -- confirm against caller.
            obj: Index tensor; ``obj[0]`` selects entries along the
                ``num_obj`` axis.

        Returns:
            Tuple ``(out_rx, out_tx, out_cx, emb)`` where, with ``k`` the
            number of indices in ``obj[0]``, ``out_rx`` is
            (k, num_points, 4), ``out_tx`` is (k, num_points, 3), ``out_cx``
            is (k, num_points, 1) with sigmoid applied, and ``emb`` is the
            gathered embedding detached from the autograd graph.
        """
        out_img = self.cnn(img)

        bs, di, _, _ = out_img.size()

        # Flatten the spatial dims, then pick the chosen pixel of every channel.
        emb = out_img.view(bs, di, -1)
        choose = choose.repeat(1, di, 1)
        emb = torch.gather(emb, 2, choose).contiguous()

        x = x.transpose(2, 1).contiguous()
        ap_x = self.feat(x, emb)

        # Rotation head.
        rx = F.relu(self.conv1_r(ap_x))
        rx = F.relu(self.conv2_r(rx))
        rx = F.relu(self.conv3_r(rx))
        rx = self.conv4_r(rx).view(bs, self.num_obj, 4, self.num_points)

        # Translation head.
        tx = F.relu(self.conv1_t(ap_x))
        tx = F.relu(self.conv2_t(tx))
        tx = F.relu(self.conv3_t(tx))
        tx = self.conv4_t(tx).view(bs, self.num_obj, 3, self.num_points)

        # Confidence head (squashed to (0, 1)).
        cx = F.relu(self.conv1_c(ap_x))
        cx = F.relu(self.conv2_c(cx))
        cx = F.relu(self.conv3_c(cx))
        cx = torch.sigmoid(self.conv4_c(cx)).view(
            bs, self.num_obj, 1, self.num_points)

        # Keep the requested object(s) of batch element 0 and move the point
        # axis in front of the value axis.
        b = 0
        out_rx, out_tx, out_cx = (
            torch.index_select(head[b], 0, obj[b])
            .contiguous().transpose(2, 1).contiguous()
            for head in (rx, tx, cx)
        )

        return out_rx, out_tx, out_cx, emb.detach()


class PoseRefineNetFeat(nn.Module):
    """Fuse points and embeddings into one 1024-channel pooled feature."""

    def __init__(self, num_points):
        """Create the 1x1 convolutions and the pooling layer.

        Args:
            num_points: Kernel size of the average pool over the point axis.
        """
        super(PoseRefineNetFeat, self).__init__()
        # Point branch: 3 -> 64 -> 128 channels.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)

        # Embedding branch: 32 -> 64 -> 128 channels.
        self.e_conv1 = nn.Conv1d(32, 64, 1)
        self.e_conv2 = nn.Conv1d(64, 128, 1)

        # Applied to both fusion levels concatenated (128 + 256 = 384).
        self.conv5 = nn.Conv1d(384, 512, 1)
        self.conv6 = nn.Conv1d(512, 1024, 1)

        self.ap1 = nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Compute the pooled feature.

        Args:
            x: Point tensor with 3 channels on dim 1 (required by ``conv1``).
            emb: Embedding tensor with 32 channels on dim 1 (required by
                ``e_conv1``).

        Returns:
            The average-pooled feature reshaped to ``(-1, 1024)``.
        """
        pts_feat = F.relu(self.conv1(x))
        emb_feat = F.relu(self.e_conv1(emb))
        pointfeat_1 = torch.cat([pts_feat, emb_feat], dim=1)

        pts_feat = F.relu(self.conv2(pts_feat))
        emb_feat = F.relu(self.e_conv2(emb_feat))
        pointfeat_2 = torch.cat([pts_feat, emb_feat], dim=1)

        pointfeat_3 = torch.cat([pointfeat_1, pointfeat_2], dim=1)

        fused = F.relu(self.conv5(pointfeat_3))
        fused = F.relu(self.conv6(fused))

        pooled = self.ap1(fused)
        return pooled.view(-1, 1024)


class PoseRefineNet(nn.Module):
    """Predict one rotation and one translation per object class.

    Two parallel stacks of linear layers (suffixes ``_r`` and ``_t``) map the
    1024-channel pooled feature to ``num_obj`` quaternions and translations.
    """

    def __init__(self, num_points, num_obj):
        """Create the fusion module and the two heads.

        Args:
            num_points: Number of points per sample, passed on to
                ``PoseRefineNetFeat``.
            num_obj: Number of object classes; every head emits one
                prediction per class.
        """
        super(PoseRefineNet, self).__init__()
        self.num_points = num_points
        self.feat = PoseRefineNetFeat(num_points)

        # The "conv" attribute names are kept although these are Linear
        # layers, so existing checkpoints still match.
        self.conv1_r = nn.Linear(1024, 512)
        self.conv1_t = nn.Linear(1024, 512)

        self.conv2_r = nn.Linear(512, 128)
        self.conv2_t = nn.Linear(512, 128)

        self.conv3_r = nn.Linear(128, num_obj * 4)  # quaternion
        self.conv3_t = nn.Linear(128, num_obj * 3)  # translation

        self.num_obj = num_obj

    def forward(self, x, emb, obj):
        """Run the network and select the outputs for ``obj``.

        Only batch element 0 of the head outputs and of ``obj`` is used when
        selecting the returned predictions.

        Args:
            x: Point tensor; dims 1 and 2 are swapped before fusion, so it is
                presumably (batch, num_points, 3) -- confirm against caller.
            emb: Embedding tensor with 32 channels on dim 1.
            obj: Index tensor; ``obj[0]`` selects entries along the
                ``num_obj`` axis.

        Returns:
            Tuple ``(out_rx, out_tx)`` of shapes (k, 4) and (k, 3), where
            ``k`` is the number of indices in ``obj[0]``.
        """
        bs = x.size(0)

        x = x.transpose(2, 1).contiguous()
        ap_x = self.feat(x, emb)

        # Rotation head.
        rx = F.relu(self.conv1_r(ap_x))
        rx = F.relu(self.conv2_r(rx))
        rx = self.conv3_r(rx).view(bs, self.num_obj, 4)

        # Translation head.
        tx = F.relu(self.conv1_t(ap_x))
        tx = F.relu(self.conv2_t(tx))
        tx = self.conv3_t(tx).view(bs, self.num_obj, 3)

        b = 0
        out_rx = torch.index_select(rx[b], 0, obj[b])
        out_tx = torch.index_select(tx[b], 0, obj[b])

        return out_rx, out_tx