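# FaceXFormer-style multi-task face analysis model: a frozen Swin-B backbone,
# SegFormer-style multi-scale feature fusion, and a token-based decoder that
# jointly predicts landmarks, head pose, attributes, landmark visibility, age,
# gender, race, and face-parsing masks.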
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Optional, Tuple, Type
from torchvision.models import swin_b, convnext_base
from .transformer import TwoWayTransformer, LayerNorm2d
from transformers.utils.generic import ModelOutput
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
            x = torch.sigmoid(x)
return x
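
# Example (sketch): a 3-layer head projecting 256-d tokens to 136 values,
# mirroring the landmark head built in FaceDecoder below:
#   head = MLP(input_dim=256, hidden_dim=256, output_dim=136, num_layers=3)
#   head(torch.randn(2, 256)).shape  # -> torch.Size([2, 136])
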
class FaceDecoder(nn.Module):
def __init__(
self,
*,
        transformer_dim: int = 256,
transformer: nn.Module,
activation: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
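        # One learnable query token per face-analysis task; the two-way
        # transformer lets these tokens attend to the image features
        # (and vice versa).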
self.landmarks_token = nn.Embedding(1, transformer_dim)
self.pose_token = nn.Embedding(1, transformer_dim)
self.attribute_token = nn.Embedding(1, transformer_dim)
self.visibility_token = nn.Embedding(1, transformer_dim)
self.age_token = nn.Embedding(1, transformer_dim)
self.gender_token = nn.Embedding(1, transformer_dim)
self.race_token = nn.Embedding(1, transformer_dim)
self.mask_tokens = nn.Embedding(11, transformer_dim)
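        # 11 mask tokens, one per face-parsing class; the hypernetwork MLP
        # below turns each token into a dynamic filter over the upscaled
        # feature map.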
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
),
activation(),
)
self.output_hypernetwork_mlps = MLP(
transformer_dim, transformer_dim, transformer_dim // 8, 3
)
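        # Task-specific prediction heads; output sizes follow the task
        # definitions (136 = 68 landmarks x 2 coords, 3 head-pose angles,
        # 40 binary attributes, 29 visibility scores, 8 age bins, 2 genders,
        # 5 race classes).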
self.landmarks_prediction_head = MLP(transformer_dim, transformer_dim, 136, 3)
self.pose_prediction_head = MLP(transformer_dim, transformer_dim, 3, 3)
self.attribute_prediction_head = MLP(transformer_dim, transformer_dim, 40, 3)
self.visibility_prediction_head = MLP(transformer_dim, transformer_dim, 29, 3)
self.age_prediction_head = MLP(transformer_dim, transformer_dim, 8, 3)
self.gender_prediction_head = MLP(transformer_dim, transformer_dim, 2, 3)
self.race_prediction_head = MLP(transformer_dim, transformer_dim, 5, 3)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
output_tokens = torch.cat(
[
self.landmarks_token.weight,
self.pose_token.weight,
self.attribute_token.weight,
self.visibility_token.weight,
self.age_token.weight,
self.gender_token.weight,
self.race_token.weight,
self.mask_tokens.weight,
],
dim=0,
)
tokens = output_tokens.unsqueeze(0).expand(image_embeddings.size(0), -1, -1)
src = image_embeddings
pos_src = image_pe.expand(image_embeddings.size(0), -1, -1, -1)
b, c, h, w = src.shape
hs, src = self.transformer(src, pos_src, tokens)
landmarks_token_out = hs[:, 0, :]
pose_token_out = hs[:, 1, :]
attribute_token_out = hs[:, 2, :]
visibility_token_out = hs[:, 3, :]
age_token_out = hs[:, 4, :]
gender_token_out = hs[:, 5, :]
race_token_out = hs[:, 6, :]
mask_token_out = hs[:, 7:, :]
landmark_output = self.landmarks_prediction_head(landmarks_token_out)
headpose_output = self.pose_prediction_head(pose_token_out)
attribute_output = self.attribute_prediction_head(attribute_token_out)
visibility_output = self.visibility_prediction_head(visibility_token_out)
age_output = self.age_prediction_head(age_token_out)
gender_output = self.gender_prediction_head(gender_token_out)
race_output = self.race_prediction_head(race_token_out)
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in = self.output_hypernetwork_mlps(mask_token_out)
b, c, h, w = upscaled_embedding.shape
seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
return (
landmark_output,
headpose_output,
attribute_output,
visibility_output,
age_output,
gender_output,
race_output,
seg_output,
)
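
# Shape sketch for FaceDecoder (illustrative, assuming transformer_dim=256 and
# a batch of 2 images with 56x56 fused features, as FaceXFormer below produces
# for 224x224 inputs):
#   landmarks (2, 136), pose (2, 3), attributes (2, 40), visibility (2, 29),
#   age (2, 8), gender (2, 2), race (2, 5), and masks (2, 11, 224, 224)
#   after the 4x output upscaling.
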
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
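
# Example (sketch): PositionEmbeddingRandom(128) maps a 56x56 grid to a
# (256, 56, 56) encoding -- sin and cos of 128 random Gaussian frequencies:
#   pe = PositionEmbeddingRandom(num_pos_feats=128)
#   pe((56, 56)).shape  # -> torch.Size([256, 56, 56])
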
class FaceXFormerMLP(nn.Module):
    """Projects one backbone feature map to the shared 256-d decoder width."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.proj = nn.Linear(input_dim, 256)  # 128 / 256 / 512 / 1024 -> 256

    def forward(self, hidden_states: torch.Tensor):
        # (B, C, H, W) -> (B, H*W, C) so the projection acts per spatial location
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.proj(hidden_states)
        return hidden_states
class FaceXFormer(nn.Module):
def __init__(self):
        super().__init__()
        # Backbone: Swin-B with ImageNet-1k weights; the classification head
        # is dropped and the whole backbone is frozen.
        swin = swin_b(weights="IMAGENET1K_V1")
        self.backbone = torch.nn.Sequential(*(list(swin.children())[:-1]))
        self.backbone.requires_grad_(False)
        # Alternative backbone: ConvNeXt-B
        # convnext = convnext_base(weights="IMAGENET1K_V1")
        # self.backbone = torch.nn.Sequential(*(list(convnext.children())[:-1]))
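        # Capture the outputs of the four Swin stages (modules "0.1", "0.3",
        # "0.5", "0.7" of the wrapped backbone) via forward hooks.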
self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
self.multi_scale_features = []
embed_dim = 1024
out_chans = 256
self.pe_layer = PositionEmbeddingRandom(out_chans // 2)
for name, module in self.backbone.named_modules():
if name in self.target_layer_names:
module.register_forward_hook(self.save_features_hook(name))
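        # Unified face decoder: a (SAM-style) two-way transformer over the
        # task tokens and the fused image features.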
self.face_decoder = FaceDecoder(
transformer_dim=256,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=256,
mlp_dim=2048,
num_heads=8,
),
)
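        # SegFormer-style multi-scale fusion: project every stage to the
        # decoder width, upsample to the highest resolution, concatenate,
        # and fuse with a 1x1 convolution.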
num_encoder_blocks = 4
hidden_sizes = [128, 256, 512, 1024]
decoder_hidden_size = 256
mlps = []
for i in range(num_encoder_blocks):
mlp = FaceXFormerMLP(input_dim=hidden_sizes[i])
mlps.append(mlp)
self.linear_c = nn.ModuleList(mlps)
self.linear_fuse = nn.Conv2d(
in_channels=decoder_hidden_size * num_encoder_blocks, # 1024
out_channels=decoder_hidden_size, # 256
kernel_size=1,
bias=False,
)
    def save_features_hook(self, name):
        def hook(module, input, output):
            # Swin blocks emit NHWC feature maps; store them as NCHW
            self.multi_scale_features.append(output.permute(0, 3, 1, 2).contiguous())
        return hook
def predict(self, x, labels, tasks):
self.multi_scale_features.clear()
        self.backbone(x)  # returned features are unused; the hooks collect stage outputs
batch_size = self.multi_scale_features[-1].shape[0]
all_hidden_states = ()
for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
encoder_hidden_state = mlp(encoder_hidden_state)
encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
encoder_hidden_state = encoder_hidden_state.reshape(
batch_size, -1, height, width
)
encoder_hidden_state = nn.functional.interpolate(
encoder_hidden_state,
size=self.multi_scale_features[0].size()[2:],
mode="bilinear",
align_corners=False,
)
all_hidden_states += (encoder_hidden_state,)
fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
image_pe = self.pe_layer(
(fused_states.shape[2], fused_states.shape[3])
).unsqueeze(0)
(
landmark_output,
headpose_output,
attribute_output,
visibility_output,
age_output,
gender_output,
race_output,
seg_output,
) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)
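        # Route each head's output to the samples whose task id requests it:
        # 0 = parsing, 1 = landmarks, 2 = head pose, 3 = attributes,
        # 4 = age/gender/race, 5 = landmark visibility.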
segmentation_indices = tasks == 0
seg_output = seg_output[segmentation_indices]
landmarks_indices = tasks == 1
landmark_output = landmark_output[landmarks_indices]
headpose_indices = tasks == 2
headpose_output = headpose_output[headpose_indices]
attribute_indices = tasks == 3
attribute_output = attribute_output[attribute_indices]
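        # Task id 4 covers age, gender, and race jointly, so all three heads
        # are indexed with age_indices.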
age_indices = tasks == 4
age_output = age_output[age_indices]
gender_output = gender_output[age_indices]
race_output = race_output[age_indices]
visibility_indices = tasks == 5
visibility_output = visibility_output[visibility_indices]
return (
landmark_output,
headpose_output,
attribute_output,
visibility_output,
age_output,
gender_output,
race_output,
seg_output,
)
def loss(
self, predictions: torch.Tensor, labels: torch.Tensor, num_items_in_batch=None
):
        # L2 (MSE) loss, summed over all elements and optionally normalized by
        # the number of items in the batch.
        loss = torch.nn.functional.mse_loss(predictions, labels, reduction="sum")
if num_items_in_batch:
loss /= num_items_in_batch
return loss
def forward(self, x, labels, num_items_in_batch=None):
self.multi_scale_features.clear()
        self.backbone(x)  # returned features are unused; the hooks collect stage outputs
batch_size = self.multi_scale_features[-1].shape[0]
all_hidden_states = ()
for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
encoder_hidden_state = mlp(encoder_hidden_state)
encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
encoder_hidden_state = encoder_hidden_state.reshape(
batch_size, -1, height, width
)
encoder_hidden_state = nn.functional.interpolate(
encoder_hidden_state,
size=self.multi_scale_features[0].size()[2:],
mode="bilinear",
align_corners=False,
)
all_hidden_states += (encoder_hidden_state,)
fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
image_pe = self.pe_layer(
(fused_states.shape[2], fused_states.shape[3])
).unsqueeze(0)
(
landmark_output,
headpose_output,
attribute_output,
visibility_output,
age_output,
gender_output,
race_output,
seg_output,
) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)
        # This training path treats every sample as the landmark task, so only
        # the landmark head contributes to the loss.
        if labels is not None:
            loss = self.loss(
                landmark_output.view(-1, 68, 2), labels, num_items_in_batch
            )
else:
loss = None
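        # Only the loss is returned, which is what the Hugging Face Trainer
        # consumes during training.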
return ModelOutput(
loss=loss,
)
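

if __name__ == "__main__":
    # Smoke test (sketch). Assumptions: the ImageNet weights are downloadable,
    # 224x224 inputs, and this file is executed as part of its package
    # (e.g. python -m ...) so the relative import of TwoWayTransformer resolves.
    model = FaceXFormer().eval()
    images = torch.randn(2, 3, 224, 224)
    tasks = torch.tensor([1, 1])  # treat both samples as the landmark task
    with torch.no_grad():
        outputs = model.predict(images, labels=None, tasks=tasks)
    print(outputs[0].shape)  # landmark predictions for the two samples: (2, 136)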