import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import Any, Optional, Tuple, Type
from torchvision.models import swin_b, convnext_base
from .transformer import TwoWayTransformer, LayerNorm2d
from transformers.utils.generic import ModelOutput


class MLP(nn.Module):
    """Simple multi-layer perceptron with ReLU activations between hidden layers."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x


class FaceDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int = 256,
        transformer: nn.Module,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        # One learned query token per face task, plus 11 mask tokens for parsing.
        self.landmarks_token = nn.Embedding(1, transformer_dim)
        self.pose_token = nn.Embedding(1, transformer_dim)
        self.attribute_token = nn.Embedding(1, transformer_dim)
        self.visibility_token = nn.Embedding(1, transformer_dim)
        self.age_token = nn.Embedding(1, transformer_dim)
        self.gender_token = nn.Embedding(1, transformer_dim)
        self.race_token = nn.Embedding(1, transformer_dim)
        self.mask_tokens = nn.Embedding(11, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.output_hypernetwork_mlps = MLP(
            transformer_dim, transformer_dim, transformer_dim // 8, 3
        )

        # Per-task prediction heads (output dims: 136 = 68 landmarks x 2, 3 pose
        # values, 40 attributes, 29 visibility scores, 8 age bins, 2 gender
        # classes, 5 race classes).
        self.landmarks_prediction_head = MLP(transformer_dim, transformer_dim, 136, 3)
        self.pose_prediction_head = MLP(transformer_dim, transformer_dim, 3, 3)
        self.attribute_prediction_head = MLP(transformer_dim, transformer_dim, 40, 3)
        self.visibility_prediction_head = MLP(transformer_dim, transformer_dim, 29, 3)
        self.age_prediction_head = MLP(transformer_dim, transformer_dim, 8, 3)
        self.gender_prediction_head = MLP(transformer_dim, transformer_dim, 2, 3)
        self.race_prediction_head = MLP(transformer_dim, transformer_dim, 5, 3)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Concatenate all task query tokens and broadcast them over the batch.
        output_tokens = torch.cat(
            [
                self.landmarks_token.weight,
                self.pose_token.weight,
                self.attribute_token.weight,
                self.visibility_token.weight,
                self.age_token.weight,
                self.gender_token.weight,
                self.race_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        tokens = output_tokens.unsqueeze(0).expand(image_embeddings.size(0), -1, -1)

        src = image_embeddings
        pos_src = image_pe.expand(image_embeddings.size(0), -1, -1, -1)
        b, c, h, w = src.shape

        # Cross-attend the task tokens and the image features.
        hs, src = self.transformer(src, pos_src, tokens)
        landmarks_token_out = hs[:, 0, :]
        pose_token_out = hs[:, 1, :]
        attribute_token_out = hs[:, 2, :]
        visibility_token_out = hs[:, 3, :]
        age_token_out = hs[:, 4, :]
        gender_token_out = hs[:, 5, :]
        race_token_out = hs[:, 6, :]
        mask_token_out = hs[:, 7:, :]

        landmark_output = self.landmarks_prediction_head(landmarks_token_out)
        headpose_output = self.pose_prediction_head(pose_token_out)
        attribute_output = self.attribute_prediction_head(attribute_token_out)
        visibility_output = self.visibility_prediction_head(visibility_token_out)
        age_output = self.age_prediction_head(age_token_out)
        gender_output = self.gender_prediction_head(gender_token_out)
        race_output = self.race_prediction_head(race_token_out)

        # Upscale the image features and predict segmentation maps with weights
        # generated from the mask tokens (hypernetwork style).
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in = self.output_hypernetwork_mlps(mask_token_out)
        b, c, h, w = upscaled_embedding.shape
        seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        return (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        )


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C


class FaceXFormerMLP(nn.Module):
    """Projects one backbone feature map into the shared 256-dim decoder space."""

    def __init__(self, input_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, 256)  # 128 / 256 / 512 / 1024 -> 256

    def forward(self, hidden_states: torch.Tensor):
        # B x C x H x W -> B x (H*W) x C, then project the channel dimension.
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.proj(hidden_states)
        return hidden_states


class FaceXFormer(nn.Module):
    def __init__(self):
        super(FaceXFormer, self).__init__()
        # Backbone: Swin-B (frozen); multi-scale features are collected via forward hooks.
        swin_v2 = swin_b(weights="IMAGENET1K_V1")
        self.backbone = torch.nn.Sequential(*(list(swin_v2.children())[:-1]))
        self.backbone.requires_grad_(False)

        # # Backbone: ConvNeXt-B (alternative)
        # convnext_v2 = convnext_base(weights='IMAGENET1K_V1')
        # self.backbone = torch.nn.Sequential(
        #     *(list(convnext_v2.children())[:-1]))

        self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
        self.multi_scale_features = []

        embed_dim = 1024
        out_chans = 256
        self.pe_layer = PositionEmbeddingRandom(out_chans // 2)

        for name, module in self.backbone.named_modules():
            if name in self.target_layer_names:
                module.register_forward_hook(self.save_features_hook(name))

        self.face_decoder = FaceDecoder(
            transformer_dim=256,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=256,
                mlp_dim=2048,
                num_heads=8,
            ),
        )

        num_encoder_blocks = 4
        hidden_sizes = [128, 256, 512, 1024]
        decoder_hidden_size = 256
        mlps = []
        for i in range(num_encoder_blocks):
            mlp = FaceXFormerMLP(input_dim=hidden_sizes[i])
            mlps.append(mlp)
        self.linear_c = nn.ModuleList(mlps)

        self.linear_fuse = nn.Conv2d(
            in_channels=decoder_hidden_size * num_encoder_blocks,  # 1024
            out_channels=decoder_hidden_size,  # 256
            kernel_size=1,
            bias=False,
        )

    def save_features_hook(self, name):
        def hook(module, input, output):
            # Swin stages emit B x H x W x C; store as B x C x H x W.
            self.multi_scale_features.append(output.permute(0, 3, 1, 2).contiguous())

        return hook

    def predict(self, x, labels, tasks):
        # The backbone forward pass only serves to fire the hooks that fill
        # self.multi_scale_features; its pooled output is unused.
        self.multi_scale_features.clear()
        _, _, h, w = x.shape
        features = self.backbone(x).squeeze()
        batch_size = self.multi_scale_features[-1].shape[0]

        # Project each stage to 256 channels and upsample to the highest resolution.
        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(
                batch_size, -1, height, width
            )
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state,
                size=self.multi_scale_features[0].size()[2:],
                mode="bilinear",
                align_corners=False,
            )
            all_hidden_states += (encoder_hidden_state,)

        fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        image_pe = self.pe_layer(
            (fused_states.shape[2], fused_states.shape[3])
        ).unsqueeze(0)

        (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        ) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)

        # Keep only the outputs whose samples carry the matching task id
        # (0: segmentation, 1: landmarks, 2: head pose, 3: attributes,
        #  4: age/gender/race, 5: visibility).
        segmentation_indices = tasks == 0
        seg_output = seg_output[segmentation_indices]
        landmarks_indices = tasks == 1
        landmark_output = landmark_output[landmarks_indices]
        headpose_indices = tasks == 2
        headpose_output = headpose_output[headpose_indices]
        attribute_indices = tasks == 3
        attribute_output = attribute_output[attribute_indices]
        age_indices = tasks == 4
        age_output = age_output[age_indices]
        gender_output = gender_output[age_indices]
        race_output = race_output[age_indices]
        visibility_indices = tasks == 5
        visibility_output = visibility_output[visibility_indices]

        return (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        )

    def loss(
        self, predictions: torch.Tensor, labels: torch.Tensor, num_items_in_batch=None
    ):
        # L2 (MSE) loss for now, summed over all elements and optionally
        # normalized by the number of items in the batch.
        loss = torch.nn.functional.mse_loss(predictions, labels, reduction="sum")
        if num_items_in_batch:
            loss /= num_items_in_batch
        return loss

    def forward(self, x, labels, num_items_in_batch=None):
        # Same multi-scale feature fusion as in predict(); the backbone forward
        # only fires the hooks that populate self.multi_scale_features.
        self.multi_scale_features.clear()
        _, _, h, w = x.shape
        features = self.backbone(x).squeeze()
        batch_size = self.multi_scale_features[-1].shape[0]

        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(
                batch_size, -1, height, width
            )
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state,
                size=self.multi_scale_features[0].size()[2:],
                mode="bilinear",
                align_corners=False,
            )
            all_hidden_states += (encoder_hidden_state,)

        fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        image_pe = self.pe_layer(
            (fused_states.shape[2], fused_states.shape[3])
        ).unsqueeze(0)

        (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        ) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)

        # All tasks are landmark prediction: reshape the 136-dim output into
        # 68 (x, y) pairs and compare against the labels.
        if labels is not None:
            loss = self.loss(
                landmark_output.view(-1, 68, 2), labels, num_items_in_batch
            )
        else:
            loss = None

        return ModelOutput(
            loss=loss,
        )
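

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): a minimal smoke test
# that builds the model and runs one forward pass on random tensors. The
# 224x224 input resolution and the (B, 68, 2) landmark label shape are
# assumptions inferred from the Swin-B backbone and the 136-dim landmark head;
# running it also requires the sibling `transformer` module and a one-time
# download of the ImageNet backbone weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = FaceXFormer().eval()
    dummy_images = torch.randn(2, 3, 224, 224)  # B x C x H x W
    dummy_landmarks = torch.rand(2, 68, 2)      # assumed normalized (x, y) pairs
    with torch.no_grad():
        out = model(dummy_images, dummy_landmarks)
    print("landmark MSE loss:", out.loss)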