import copy

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from models.ops.modules import MSDeformAttn

class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # multi-scale deformable self-attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # feed-forward network
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        # self-attention over the flattened multi-scale feature map
        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes,
                              level_start_index, padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # feed-forward network
        src = self.forward_ffn(src)

        return src


class DeformableTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # pixel centers of this level, normalized by the valid (unpadded) extent
            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for _, layer in enumerate(self.layers):
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

        return output
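
# Expected encoder input layout (illustrative note, shapes inferred from the calls above):
#   src                (batch, sum_l H_l * W_l, d_model)  multi-scale features, flattened per level and concatenated
#   spatial_shapes     (n_levels, 2)                      rows are (H_l, W_l)
#   level_start_index  (n_levels,)                        offset of each level inside the token dimension of src
#   valid_ratios       (batch, n_levels, 2)               valid (width, height) fraction of each padded feature map
#   pos                same shape as src                  optional positional encoding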


class DeformableAttnDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # multi-scale deformable cross-attention into the encoder memory
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # feed-forward network
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
                src_padding_mask=None,
                key_padding_mask=None):
        # cross-attention from the queries into the multi-scale encoder memory
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                               reference_points,
                               src, src_spatial_shapes, level_start_index, src_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # feed-forward network
        tgt = self.forward_ffn(tgt)

        return tgt
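
# Note: this is a cross-attention-only decoder layer (no self-attention over the
# queries); its signature matches the `with_sa=False` call path in
# DeformableTransformerDecoder below, which omits the `get_image_feat` argument.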


class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # multi-scale deformable cross-attention into the encoder memory
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self-attention among the queries
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # feed-forward network
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index,
                src_padding_mask=None,
                key_padding_mask=None,
                get_image_feat=True):
        # self-attention among the queries; nn.MultiheadAttention expects
        # (sequence, batch, dim), hence the transposes around the call
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1),
                              key_padding_mask=key_padding_mask)[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        if get_image_feat:
            # deformable cross-attention into the multi-scale encoder memory
            tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                                   reference_points,
                                   src, src_spatial_shapes, level_start_index, src_padding_mask)
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)

        # feed-forward network
        tgt = self.forward_ffn(tgt)

        return tgt


class DeformableTransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, return_intermediate=False, with_sa=True):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        self.with_sa = with_sa

    def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
                query_pos=None, src_padding_mask=None, key_padding_mask=None, get_image_feat=True):
        output = tgt

        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # rescale the reference points per level by the valid ratios of the padded feature maps
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] \
                                         * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
            if self.with_sa:
                output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes,
                               src_level_start_index, src_padding_mask, key_padding_mask, get_image_feat)
            else:
                output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes,
                               src_level_start_index, src_padding_mask, key_padding_mask)

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output, reference_points
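
# Note on reference points (inferred from the branches above): they may be 2-d
# normalized (x, y) centers or 4-d normalized (x, y, w, h) boxes; either way each
# layer receives them rescaled per feature level by the valid ratios.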


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
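

if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the original pipeline). It only exercises
    # the pure-PyTorch get_reference_points helper with made-up feature-map sizes; the
    # deformable attention op itself is not run here, although importing this module
    # still assumes models.ops is available.
    spatial_shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]   # (H_l, W_l) per level
    valid_ratios = torch.ones(2, len(spatial_shapes), 2)    # batch of 2, no padding
    ref = DeformableTransformerEncoder.get_reference_points(spatial_shapes, valid_ratios, device="cpu")
    # Expected: (batch, sum_l H_l * W_l, n_levels, 2) normalized (x, y) points in [0, 1]
    print(ref.shape)  # torch.Size([2, 1360, 4, 2])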