import torch
import torch.nn as nn
import torch.nn.functional as F

from ldm.modules.attention import BasicTransformerBlock
from ldm.modules.diffusionmodules.util import checkpoint, FourierEmbedder


class PositionNet(nn.Module):
    def __init__(self, positive_len, out_dim, fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 (sin, cos) x 4 (xyxy box coordinates)

        # MLP that fuses each box's text feature with its position embedding
        # into a single grounding token of width out_dim.
        self.linears = nn.Sequential(
            nn.Linear(self.positive_len + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        # Learned placeholders substituted for padded (masked-out) entries.
        self.null_positive_feature = nn.Parameter(torch.zeros([self.positive_len]))
        self.null_position_feature = nn.Parameter(torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)  # B*N --> B*N*1, so it broadcasts over features

        # Embed box coordinates with Fourier features: B*N*4 --> B*N*position_dim
        xyxy_embedding = self.fourier_embedder(boxes)

        # Learned null embeddings, reshaped to broadcast over B and N
        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # Replace padded entries with the learned null embeddings
        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs
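

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). The sizes
# are assumptions: positive_len=768 to match a CLIP-style text embedding and
# out_dim=768 to match the cross-attention context width; adjust both to the
# actual configuration used by the surrounding model.
if __name__ == "__main__":
    net = PositionNet(positive_len=768, out_dim=768)

    B, N = 2, 30                                 # 2 images, up to 30 boxes each
    boxes = torch.rand(B, N, 4)                  # normalized xyxy coordinates
    masks = torch.zeros(B, N)
    masks[:, :3] = 1                             # only the first 3 boxes are real
    positive_embeddings = torch.rand(B, N, 768)  # per-box text features

    objs = net(boxes, masks, positive_embeddings)
    print(objs.shape)                            # torch.Size([2, 30, 768])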