Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import init | |
import einops | |
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 3x3 ``nn.Conv2d``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride (default 1).
        padding: zero padding on every side (default 1, which keeps the
            spatial size unchanged at stride 1).
        bias: whether the convolution has a learnable bias.
        groups: number of blocked input->output channel groups.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=padding,
                       bias=bias, groups=groups)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Build a layer that doubles the spatial resolution.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        mode: ``'transpose'`` selects a 2x2, stride-2 transposed convolution;
            any other value selects bilinear ``nn.Upsample`` (scale 2)
            followed by a 1x1 convolution.
    """
    if mode != 'transpose':
        # Interpolate first, then let a pointwise conv map the channels.
        upsample = nn.Upsample(mode='bilinear', scale_factor=2)
        return nn.Sequential(upsample, conv1x1(in_channels, out_channels))
    return nn.ConvTranspose2d(in_channels, out_channels,
                              kernel_size=2, stride=2)
def conv1x1(in_channels, out_channels, groups=1):
    """Build a pointwise (1x1 kernel, stride 1, with bias) ``nn.Conv2d``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        groups: number of blocked input->output channel groups.
    """
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=1, groups=groups)
class ConvTriplane3dAware(nn.Module):
    """3D-aware triplane convolution (as described in RODIN).

    Each output plane is produced by one 2D conv applied to the channel-wise
    concatenation of (a) its own input plane and (b) the two other planes,
    each mean-pooled along the axis it does not share with the target plane
    and tiled back over the target plane's spatial grid.
    """

    def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
        """
        Args:
            internal_conv_f: callable ``(in_ch, out_ch) -> nn.Module`` that
                returns a 2D convolution. It is called three times (one conv
                per plane) with ``in_ch = 3 * in_channels``.
            in_channels: channels of each input plane.
            out_channels: channels of each output plane.
            order: axis order of the third plane, ``'xz'`` or ``'zx'``.

        Raises:
            ValueError: if ``order`` is not ``'xz'`` or ``'zx'``.
        """
        super(ConvTriplane3dAware, self).__init__()
        # Explicit check instead of `assert`, so it is not stripped by `python -O`.
        if order not in ('xz', 'zx'):
            raise ValueError(f"order must be 'xz' or 'zx', got {order!r}")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.order = order
        # One conv per plane; each sees its own plane plus two pooled planes
        # stacked on the channel dim, hence 3 * in_channels inputs.
        self.plane_convs = nn.ModuleList([
            internal_conv_f(3 * self.in_channels, self.out_channels)
            for _ in range(3)])

    def forward(self, triplanes_list):
        """
        Args:
            triplanes_list: sequence of three (B, Ci, H, W) tensors in
                xy, yz, (zx or xz) order depending on ``self.order``.
        Returns:
            list of three (B, Co, H, W) tensors in the same plane order.
        """
        inps = list(triplanes_list)  # copy so the caller's sequence is untouched
        # Plane indices, named after the axis each plane does not contain.
        xp = 1  # (yz)
        yp = 2  # (zx)
        zp = 0  # (xy)
        if self.order == 'xz':
            # Work internally in zx order so the planes are cyclic: xy, yz, zx.
            inps[yp] = einops.rearrange(inps[yp], 'b c x z -> b c z x')

        oplanes = [None] * 3
        for iplane in (zp, xp, yp):
            # With cyclic plane order: i-plane spans (j, k),
            # j-plane spans (k, i) and k-plane spans (i, j).
            jplane = (iplane + 1) % 3
            kplane = (iplane + 2) % 3
            ifeat = inps[iplane]
            # j-plane (k, i): average out i -> (k, 1) -> (1, k) -> tile to (j, k)
            jpool = torch.mean(inps[jplane], dim=3, keepdim=True)
            jpool = einops.rearrange(jpool, 'b c k 1 -> b c 1 k')
            jpool = einops.repeat(jpool, 'b c 1 k -> b c j k', j=ifeat.size(2))
            # k-plane (i, j): average out i -> (1, j) -> (j, 1) -> tile to (j, k)
            kpool = torch.mean(inps[kplane], dim=2, keepdim=True)
            kpool = einops.rearrange(kpool, 'b c 1 j -> b c j 1')
            kpool = einops.repeat(kpool, 'b c j 1 -> b c j k', k=ifeat.size(3))
            # Concatenate on the channel dim and mix with this plane's conv.
            catfeat = torch.cat([ifeat, jpool, kpool], dim=1)
            oplanes[iplane] = self.plane_convs[iplane](catfeat)

        if self.order == 'xz':
            # Restore the caller's xz order for the third plane.
            oplanes[yp] = einops.rearrange(oplanes[yp], 'b c z x -> b c x z')
        return oplanes
def roll_triplanes(triplanes_list):
    """Stack three (B, C, H, W) planes vertically into one (B, C, 3*H, W) tensor.

    Plane 0 fills rows [0, H), plane 1 rows [H, 2H) and plane 2 rows [2H, 3H).
    Exactly three planes are required (enforced by ``tri=3``).
    """
    stacked = torch.stack(triplanes_list, dim=2)  # (B, C, 3, H, W)
    return einops.rearrange(stacked, 'b c tri h w -> b c (tri h) w', tri=3)
def unroll_triplanes(rolled_triplane):
    """Split a rolled (B, C, 3*H, W) tensor back into its three (B, C, H, W) planes.

    Inverse of ``roll_triplanes``; returns a tuple of three tensors.
    """
    planes = einops.rearrange(rolled_triplane,
                              'b c (tri h) w -> b c tri h w', tri=3)
    return planes.unbind(dim=2)
def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
    """Build a ``ConvTriplane3dAware`` whose per-plane convolutions are 1x1.

    Extra keyword arguments (e.g. ``groups``) are forwarded to ``conv1x1``.
    """
    def make_plane_conv(inp, out):
        return conv1x1(inp, out, **kwargs)

    return ConvTriplane3dAware(make_plane_conv, in_channels, out_channels,
                               order=order)
def Normalize(in_channels, num_groups=32):
    """Build a GroupNorm layer (eps=1e-6, learnable affine) for ``in_channels``.

    ``torch.nn.GroupNorm`` requires ``in_channels`` to be divisible by the
    number of groups. The group count used is therefore the largest divisor
    of ``in_channels`` that does not exceed ``num_groups``. That is
    ``num_groups`` itself whenever it already divides ``in_channels``, and
    ``in_channels`` when fewer than ``num_groups`` channels are given.

    Args:
        in_channels: number of channels to normalise.
        num_groups: upper bound on the number of groups (default 32).

    Returns:
        A ``torch.nn.GroupNorm`` module.
    """
    groups = max(1, min(in_channels, num_groups))
    # Step down to a divisor rather than letting GroupNorm raise
    # (e.g. 48 channels -> 24 groups instead of an error with 32).
    while in_channels % groups != 0:
        groups -= 1
    return torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels,
                              eps=1e-6, affine=True)
def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels: channels of the input (and output) feature map.
            with_conv: if truthy, apply a channel-preserving 3x3 conv
                (stride 1, padding 1) after interpolation.
        """
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                    kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):
    """2x spatial downsampling by a strided 3x3 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels: channels of the input (and output) feature map.
            with_conv: if truthy, use a stride-2 3x3 conv; otherwise use
                2x2 average pooling.
        """
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # torch convs only pad symmetrically, so forward() pads the
            # right/bottom edge by one pixel itself and the conv uses padding=0.
            self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                        kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # (left, right, top, bottom) = (0, 1, 0, 1): zero-pad right and bottom only.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ResnetBlock3dAware(nn.Module):
    """Residual block on rolled triplanes: 3x3 conv, 3D-aware 1x1 conv, 3x3 conv.

    Each conv is preceded by ``Normalize`` (GroupNorm) and ``nonlinearity``
    (Swish). The input is added back at the end, through a 1x1 projection
    (``nin_shortcut``) when the channel counts differ.
    """

    def __init__(self, in_channels, out_channels=None):
        """
        Args:
            in_channels: channels of the rolled input.
            out_channels: channels of the output; defaults to ``in_channels``.
        """
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels)
        self.norm_mid = Normalize(out_channels)
        self.conv_3daware = conv1x1triplane3daware(out_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        if in_channels != out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                kernel_size=1, stride=1,
                                                padding=0)

    def forward(self, x):
        # 3x3 conv over the rolled (B, C, 3*H, W) tensor. Mostly within-plane,
        # but rows at plane boundaries also see the neighbouring plane.
        h = self.conv1(nonlinearity(self.norm1(x)))
        # Cross-plane communication: unroll, 3D-aware 1x1 conv, roll back.
        h = nonlinearity(self.norm_mid(h))
        h = roll_triplanes(self.conv_3daware(unroll_triplanes(h)))
        # Second 3x3 conv on the rolled tensor.
        h = self.conv2(nonlinearity(self.norm2(h)))
        needs_projection = self.in_channels != self.out_channels
        shortcut = self.nin_shortcut(x) if needs_projection else x
        return shortcut + h
class DownConv3dAware(nn.Module):
    """Encoder stage: one ``ResnetBlock3dAware`` followed by optional 2x downsampling.

    Works on rolled triplanes ``(B, C, 3*H, W)``. Downsampling uses
    ``Downsample``: 2x2 average pooling, or a stride-2 3x3 conv when
    ``with_conv``. It is applied to each plane separately, so it never mixes
    rows across plane boundaries.
    """

    def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
        """
        Args:
            in_channels: channels of the rolled input.
            out_channels: channels produced by the residual block.
            downsample: whether to halve the per-plane resolution after the block.
            with_conv: forwarded to ``Downsample`` (strided conv vs. avg pool).
        """
        super(DownConv3dAware, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block = ResnetBlock3dAware(in_channels=in_channels,
                                        out_channels=out_channels)
        self.do_downsample = downsample
        # forward() folds the three planes into the batch dim, so the
        # downsampler always sees exactly out_channels channels.
        self.downsample = Downsample(out_channels, with_conv=with_conv)

    def forward(self, x):
        """
        Args:
            x: rolled triplanes, (B, C_in, 3*H, W).
        Returns:
            (x, before_pool): ``x`` is the stage output, (B, C_out, 3*(H//2), W//2)
            when downsampling is enabled, else (B, C_out, 3*H, W).
            ``before_pool`` is the block output before downsampling
            (the U-Net uses it as the skip connection).
        """
        x = self.block(x)
        before_pool = x
        if self.do_downsample:
            # Fold planes into the batch dim so each plane is downsampled on its
            # own. The previous channel-wise fold, (c tri), gave a with_conv
            # Downsample 3*out_channels inputs for a conv built for out_channels.
            x = einops.rearrange(x, 'b c (tri h) w -> (b tri) c h w', tri=3)
            x = self.downsample(x)
            x = einops.rearrange(x, '(b tri) c h w -> b c (tri h) w', tri=3)
        return x, before_pool
class UpConv3dAware(nn.Module):
    """Decoder stage: 2x upsample, merge with the encoder skip, then ``ResnetBlock3dAware``.

    Works on rolled triplanes ``(B, C, 3*H, W)``.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', with_conv=False):  # up_mode='transpose', ):
        """
        Args:
            in_channels: channels of the decoder input ``from_up``.
            out_channels: channels of the encoder skip ``from_down`` and of
                the stage output.
            merge_mode: ``'concat'`` concatenates the upsampled decoder tensor
                and the skip along channels; any other value adds them.
            with_conv: forwarded to ``Upsample`` (3x3 conv after interpolation).
        """
        super(UpConv3dAware, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.upsample = Upsample(in_channels, with_conv)
        if self.merge_mode == 'concat':
            merged_channels = in_channels + out_channels
        else:
            # Addition needs matching channel counts, and Upsample keeps
            # in_channels. Project the decoder features to out_channels
            # when they differ (previously this path crashed on the add).
            if in_channels != out_channels:
                self.up_proj = conv1x1(in_channels, out_channels)
            merged_channels = out_channels
        self.norm1 = Normalize(merged_channels)
        self.block = ResnetBlock3dAware(in_channels=merged_channels,
                                        out_channels=out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass on rolled tensors (b c (tri*h) w).

        Arguments:
            from_down: skip tensor from the encoder pathway (out_channels).
            from_up: tensor from the decoder pathway (in_channels), at half
                the skip's resolution.
        Returns:
            rolled tensor with out_channels channels at the skip's resolution.
        """
        from_up = self.upsample(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            if self.in_channels != self.out_channels:
                from_up = self.up_proj(from_up)
            x = from_up + from_down
        x = self.norm1(x)
        return self.block(x)
class UNetTriplane3dAware(nn.Module):
    """U-Net over rolled triplanes built from 3D-aware residual blocks.

    Input and output are stacked triplanes ``(B, 3, C, H, W)``. Internally the
    three planes are rolled into a single ``(B, C, 3*H, W)`` tensor.
    """

    def __init__(self, out_channels, in_channels=3, depth=5,
                 start_filts=64,  # up_mode='transpose',
                 use_initial_conv=False,
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            out_channels: int, channels per plane of the output.
            in_channels: int, channels per plane of the input (default 3).
            depth: int, number of encoder stages. The first ``depth - 1``
                stages halve the resolution, so H and W should be divisible
                by ``2 ** (depth - 1)`` for the skip shapes to match.
            start_filts: int, channels of the first encoder stage; encoder
                stage i has ``start_filts * 2**i`` channels.
            use_initial_conv: if True, a 1x1 conv first maps ``in_channels``
                to ``start_filts``.
            merge_mode: skip merge mode forwarded to each ``UpConv3dAware``.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``depth < 1``.
        """
        super(UNetTriplane3dAware, self).__init__()
        if depth < 1:
            # With no encoder stages there is no channel count for the head.
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth
        self.use_initial_conv = use_initial_conv
        if use_initial_conv:
            self.conv_initial = conv1x1(self.in_channels, self.start_filts)

        down_convs = []
        up_convs = []
        # Encoder: channels double per stage; the last stage does not downsample.
        outs = self.start_filts if use_initial_conv else self.in_channels
        for i in range(depth):
            ins = outs
            outs = self.start_filts * (2 ** i)
            down_convs.append(DownConv3dAware(ins, outs, downsample=(i < depth - 1)))
        # Decoder: channels halve per stage.
        for _ in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv3dAware(ins, outs, merge_mode=merge_mode))

        # Register as ModuleLists so parameters are tracked.
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)
        self.norm_out = Normalize(outs)
        self.conv_final = conv1x1(outs, self.out_channels)
        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal weights and zero bias for ``nn.Conv2d``; other modules are untouched.

        Must be a staticmethod: ``reset_params`` calls it as ``self.weight_init(m)``.
        Without the decorator that call passes two arguments and raises TypeError.
        """
        if isinstance(m, nn.Conv2d):
            # init.xavier_normal_(m.weight, gain=0.1)
            init.xavier_normal_(m.weight)
            if m.bias is not None:  # convs built with bias=False have no bias
                init.constant_(m.bias, 0)

    def reset_params(self):
        """Re-initialise every ``nn.Conv2d`` in this network via ``weight_init``."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, x):
        """
        Args:
            x: stacked triplanes, (B, 3, C_in, H, W).
        Returns:
            stacked triplanes, (B, 3, out_channels, H, W).
        """
        # Roll the three planes into one (B, C, 3*H, W) tensor.
        x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)
        if self.use_initial_conv:
            x = self.conv_initial(x)

        # Encoder: keep each stage's pre-downsampling output for the skips.
        encoder_outs = []
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # Decoder: the deepest encoder output is x itself, so the skips start
        # at the second-to-last entry and walk back towards the first.
        for i, module in enumerate(self.up_convs):
            x = module(encoder_outs[-(i + 2)], x)

        # Output head: GroupNorm -> Swish -> 1x1 conv; no final activation.
        x = self.conv_final(nonlinearity(self.norm_out(x)))
        # Unroll back to stacked triplanes.
        return einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
def setup_unet(output_channels, input_channels, unet_cfg):
    """Build the U-Net described by ``unet_cfg``.

    Args:
        output_channels: channels per output plane.
        input_channels: channels per input plane.
        unet_cfg: mapping read via the keys ``'use_3d_aware'``, ``'rolled'``,
            ``'depth'``, ``'use_initial_conv'`` and ``'start_hidden_channels'``.

    Returns:
        A ``UNetTriplane3dAware`` instance.

    Raises:
        NotImplementedError: if ``unet_cfg['use_3d_aware']`` is falsy; only
            the 3D-aware triplane U-Net is implemented.
        ValueError: if the 3D-aware U-Net is requested without ``'rolled'``.
    """
    if not unet_cfg['use_3d_aware']:
        raise NotImplementedError(
            "only the 3D-aware triplane U-Net (use_3d_aware=True) is implemented")
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if not unet_cfg['rolled']:
        raise ValueError(
            "the 3D-aware triplane U-Net requires unet_cfg['rolled'] to be true")
    return UNetTriplane3dAware(
        out_channels=output_channels,
        in_channels=input_channels,
        depth=unet_cfg['depth'],
        use_initial_conv=unet_cfg['use_initial_conv'],
        start_filts=unet_cfg['start_hidden_channels'],
    )