# Tailor3D: openlrm/models/modeling_lrm.py
# Snapshot 20240706, commit 52d68d4 (author: alexzyqi)
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from accelerate.logging import get_logger
from .embedder import CameraEmbedder
from .transformer import TransformerDecoder
from .rendering.synthesizer import TriplaneSynthesizer
from .utils import zero_module
import loratorch as lora
from .swin_transformer import CrossAttentionLayer
# Module-level logger from ``accelerate.logging``; used below to report which image encoder is selected.
logger = get_logger(__name__)
class ModelLRM(nn.Module):
    """
    Full large reconstruction model (LRM) that predicts and renders triplanes.

    A source image is encoded into triplane features by a transformer
    conditioned on image features and a camera embedding. When ``forward``
    also receives a back-view image (``image_back``), a second triplane is
    predicted from it, re-oriented to line up with the front triplane, and the
    two are fused either by a small convolutional stack (``conv_fuse``) or by a
    Swin-style cross-attention layer (``swin_ca_fuse``) before rendering.
    """

    def __init__(self, camera_embed_dim: int, rendering_samples_per_ray: int,
                 transformer_dim: int, transformer_layers: int, transformer_heads: int,
                 triplane_low_res: int, triplane_high_res: int, triplane_dim: int,
                 encoder_freeze: bool = True, encoder_type: str = 'dino',
                 encoder_model_name: str = 'facebook/dino-vitb16', encoder_feat_dim: int = 768,
                 model_lora_rank: int = 0, conv_fuse=False,
                 swin_ca_fuse=False, ca_dim=32, ca_depth=2, ca_num_heads=8, ca_window_size=2):
        """
        Build all sub-modules.

        Args:
            camera_embed_dim: Output width of the camera embedder; also the
                transformer's modulation dimension.
            rendering_samples_per_ray: Passed to ``TriplaneSynthesizer`` as
                ``samples_per_ray``.
            transformer_dim: Token width of the triplane transformer.
            transformer_layers: Number of transformer layers.
            transformer_heads: Number of attention heads.
            triplane_low_res: Side length of each plane's token grid; the
                transformer works on ``3 * triplane_low_res ** 2`` tokens.
            triplane_high_res: Spatial size used by the conv-fusion LayerNorms.
            triplane_dim: Channel count of each upsampled plane.
            encoder_freeze: Forwarded to the encoder wrapper as ``freeze``.
            encoder_type: ``'dino'`` or ``'dinov2'`` (case-insensitive).
            encoder_model_name: Forwarded to the encoder wrapper as ``model_name``.
            encoder_feat_dim: Expected last dimension of encoder features; also
                the transformer's conditioning dimension.
            model_lora_rank: When > 0, LoRA is enabled in the transformer, only
                LoRA parameters of the transformer are left trainable, a
                front/back fusion module is built, and the encoder and camera
                embedder are frozen.
            conv_fuse: Fuse front/back triplanes with a conv stack (takes
                precedence over ``swin_ca_fuse``).
            swin_ca_fuse: Fuse front/back triplanes with ``CrossAttentionLayer``.
            ca_dim, ca_depth, ca_num_heads, ca_window_size: Forwarded to
                ``CrossAttentionLayer``. NOTE(review): ``ca_dim`` presumably must
                equal ``triplane_dim`` — confirm.

        Raises:
            ValueError: If ``model_lora_rank > 0`` and neither fusion method is set.
        """
        super().__init__()
        # attributes
        self.encoder_feat_dim = encoder_feat_dim
        self.camera_embed_dim = camera_embed_dim
        self.triplane_low_res = triplane_low_res
        self.triplane_high_res = triplane_high_res
        self.triplane_dim = triplane_dim
        self.conv_fuse = conv_fuse
        self.swin_ca_fuse = swin_ca_fuse
        use_lora = model_lora_rank > 0

        # modules
        self.encoder = self._encoder_fn(encoder_type)(
            model_name=encoder_model_name,
            freeze=encoder_freeze,
        )
        # 12 + 4 raw camera values (presumably flattened 3x4 extrinsics + 4 intrinsics — confirm).
        self.camera_embedder = CameraEmbedder(
            raw_dim=12+4, embed_dim=camera_embed_dim,
        )
        # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
        self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5)
        # Only pass ``lora_rank`` when LoRA is requested so the decoder's own default applies otherwise.
        transformer_kwargs = {'lora_rank': model_lora_rank} if use_lora else {}
        self.transformer = TransformerDecoder(
            block_type='cond_mod',
            num_layers=transformer_layers, num_heads=transformer_heads,
            inner_dim=transformer_dim, cond_dim=encoder_feat_dim, mod_dim=camera_embed_dim,
            **transformer_kwargs,
        )
        if use_lora:
            lora.mark_only_lora_as_trainable(self.transformer)
        # kernel_size=2, stride=2, padding=0 doubles the spatial size: low_res -> 2 * low_res.
        self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
        self.synthesizer = TriplaneSynthesizer(
            triplane_dim=triplane_dim, samples_per_ray=rendering_samples_per_ray,
        )

        # Front/back fusion modules exist only in the LoRA fine-tuning setup.
        if use_lora:
            if self.conv_fuse:
                # NOTE(review): the LayerNorm shapes use triplane_high_res, but the upsampler
                # emits planes of side 2 * triplane_low_res; the two must agree for forward to work.
                self.front_back_conv = nn.ModuleList([
                    nn.Conv2d(in_channels=triplane_dim*2, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1),
                    nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]),
                    nn.GELU(),
                    nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim*4, kernel_size=(3, 3), stride=(1, 1), padding=1),
                    nn.LayerNorm([triplane_dim*4, triplane_high_res, triplane_high_res]),
                    nn.GELU(),
                    nn.Conv2d(in_channels=triplane_dim*4, out_channels=triplane_dim, kernel_size=(3, 3), stride=(1, 1), padding=1),
                ])
            elif self.swin_ca_fuse:
                self.swin_cross_attention = CrossAttentionLayer(dim=ca_dim, depth=ca_depth, num_heads=ca_num_heads, window_size=ca_window_size)
            else:
                raise ValueError("You need to specify a method for fusing the front and the back.")
            self.freeze_modules(encoder=True, camera_embedder=True)

    def freeze_modules(self, encoder=False, camera_embedder=False,
                       pos_embed=False, transformer=False, upsampler=False,
                       synthesizer=False):
        """
        Disable gradients for the selected components.

        A ``False`` flag leaves that component untouched; it never re-enables
        gradients.
        """
        targets = (
            (encoder, 'encoder'),
            (camera_embedder, 'camera_embedder'),
            (pos_embed, 'pos_embed'),
            (transformer, 'transformer'),
            (upsampler, 'upsampler'),
            (synthesizer, 'synthesizer'),
        )
        for enabled, name in targets:
            if not enabled:
                continue
            target = getattr(self, name)
            if isinstance(target, nn.Parameter):
                # ``pos_embed`` is a bare Parameter and has no ``.parameters()`` method.
                target.requires_grad_(False)
            else:
                for param in target.parameters():
                    param.requires_grad = False

    @staticmethod
    def _encoder_fn(encoder_type: str):
        """Return the encoder wrapper class for ``'dino'`` or ``'dinov2'``."""
        encoder_type = encoder_type.lower()
        assert encoder_type in ['dino', 'dinov2'], "Unsupported encoder type"
        if encoder_type == 'dino':
            from .encoders.dino_wrapper import DinoWrapper
            logger.info("Using DINO as the encoder")
            return DinoWrapper
        elif encoder_type == 'dinov2':
            from .encoders.dinov2_wrapper import Dinov2Wrapper
            logger.info("Using DINOv2 as the encoder")
            return Dinov2Wrapper
        # Only reachable when asserts are stripped (python -O).
        raise ValueError(f"Unsupported encoder type: {encoder_type}")

    def forward_transformer(self, image_feats, camera_embeddings):
        """Run the triplane transformer on the learned positional tokens.

        Returns tokens of shape [N, 3 * triplane_low_res ** 2, transformer_dim].
        """
        assert image_feats.shape[0] == camera_embeddings.shape[0], \
            "Batch size mismatch for image_feats and camera_embeddings!"
        N = image_feats.shape[0]
        x = self.pos_embed.repeat(N, 1, 1)  # [N, L, D]
        x = self.transformer(
            x,
            cond=image_feats,
            mod=camera_embeddings,
        )
        return x

    def reshape_upsample(self, tokens):
        """Reshape transformer tokens into 3 planes and upsample each one.

        tokens: [N, 3 * H * W, D] with H = W = triplane_low_res.
        Returns: [N, 3, D', 2H, 2W] where D' is the upsampler's output channels.
        """
        N = tokens.shape[0]
        H = W = self.triplane_low_res
        x = tokens.view(N, 3, H, W, -1)
        x = torch.einsum('nihwd->indhw', x)  # [3, N, D, H, W]
        x = x.contiguous().view(3*N, -1, H, W)  # [3*N, D, H, W]
        x = self.upsampler(x)  # [3*N, D', H', W']
        x = x.view(3, N, *x.shape[-3:])  # [3, N, D', H', W']
        x = torch.einsum('indhw->nidhw', x)  # [N, 3, D', H', W']
        x = x.contiguous()
        return x

    @torch.compile
    def forward_planes(self, image, camera):
        """Predict triplanes [N, 3, C, H, W] from an image and its camera.

        image: [N, C_img, H_img, W_img]; camera: [N, D_cam_raw].
        """
        N = image.shape[0]
        # encode image
        image_feats = self.encoder(image)
        assert image_feats.shape[-1] == self.encoder_feat_dim, \
            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"
        # embed camera
        camera_embeddings = self.camera_embedder(camera)
        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"
        # transformer generating planes
        tokens = self.forward_transformer(image_feats, camera_embeddings)
        planes = self.reshape_upsample(tokens)
        assert planes.shape[0] == N, "Batch size mismatch for planes"
        assert planes.shape[1] == 3, "Planes should have 3 channels"
        return planes

    @staticmethod
    def _align_back_planes(back_planes):
        """Re-orient triplanes predicted from the back view to line up with the front view.

        back_planes: [N, 3, C, H, W]. Returns a new tensor of the same shape
        (the input is not modified in place).
        """
        return torch.stack([
            torch.flip(back_planes[:, 0], dims=[-2, -1]),  # XY plane: flip both spatial axes
            torch.flip(back_planes[:, 1], dims=[-1]),      # XZ plane: flip the last axis
            torch.flip(back_planes[:, 2], dims=[-2]),      # YZ plane: flip the second-to-last axis
        ], dim=1)

    def _fuse_front_back(self, front_planes, back_planes):
        """Fuse aligned front/back triplanes (both [N, 3, C, H, W]) into one triplane.

        Raises:
            ValueError: If neither ``conv_fuse`` nor ``swin_ca_fuse`` is enabled.
        """
        bs, num_planes, channels, height, width = front_planes.shape
        if self.conv_fuse:
            planes = torch.cat((front_planes, back_planes), dim=2)
            planes = planes.reshape(-1, channels*2, height, width)
            for layer in self.front_back_conv:
                planes = layer(planes)
            return planes.view(bs, num_planes, -1, height, width)
        if self.swin_ca_fuse:
            # [bs, 3, C, H, W] -> [bs*3, C, H*W] -> [bs*3, H*W, C]
            front_tokens = front_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous()
            back_tokens = back_planes.reshape(bs*num_planes, channels, height*width).permute(0, 2, 1).contiguous()
            fused = self.swin_cross_attention(front_tokens, back_tokens, height, width)[0]
            return fused.permute(0, 2, 1).reshape(bs, num_planes, channels, height, width)
        raise ValueError(
            "image_back was given but no front/back fusion method is enabled "
            "(set conv_fuse or swin_ca_fuse; fusion modules are built only when model_lora_rank > 0)."
        )

    def forward(self, image, source_camera, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size: int,
                image_back=None,):
        """Predict triplanes (optionally fusing a back view) and render target views.

        Args:
            image: [N, C_img, H_img, W_img]
            source_camera: [N, D_cam_raw]; also used for ``image_back``.
            render_cameras: [N, M, D_cam_render]
            render_anchors: [N, M, 2]
            render_resolutions: [N, M, 1]
            render_bg_colors: [N, M, 1]
            render_region_size: int
            image_back: Optional back-view image with the same layout as ``image``.

        Returns:
            dict with ``'planes'`` plus every entry returned by the synthesizer.
        """
        assert image.shape[0] == source_camera.shape[0], "Batch size mismatch for image and source_camera"
        assert image.shape[0] == render_cameras.shape[0], "Batch size mismatch for image and render_cameras"
        assert image.shape[0] == render_anchors.shape[0], "Batch size mismatch for image and render_anchors"
        assert image.shape[0] == render_resolutions.shape[0], "Batch size mismatch for image and render_resolutions"
        assert image.shape[0] == render_bg_colors.shape[0], "Batch size mismatch for image and render_bg_colors"
        N, M = render_cameras.shape[:2]

        if image_back is not None:
            front_planes = self.forward_planes(image, source_camera)
            back_planes = self._align_back_planes(self.forward_planes(image_back, source_camera))
            planes = self._fuse_front_back(front_planes, back_planes)
        else:
            planes = self.forward_planes(image, source_camera)

        # render target views
        render_results = self.synthesizer(planes, render_cameras, render_anchors, render_resolutions, render_bg_colors, render_region_size)
        assert render_results['images_rgb'].shape[0] == N, "Batch size mismatch for render_results"
        assert render_results['images_rgb'].shape[1] == M, "Number of rendered views should be consistent with render_cameras"
        return {
            'planes': planes,
            **render_results,
        }