# Origin: test/direct3d_s2/models/conditioner.py
# Uploaded by wushuang98 in commit bcb05d1 ("Upload 197 files", verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
class DinoEncoder(nn.Module):
    """Image conditioner built around a backbone fetched with ``torch.hub``.

    The constructor loads the pretrained backbone, switches it to eval mode
    and prepares a resize / center-crop / normalize preprocessing pipeline.
    ``forward`` returns the backbone's ``'x_prenorm'`` tokens after a
    parameter-free layer normalization and, optionally, a per-patch boolean
    mask derived from a pixel mask.

    Args:
        model: Hub repository passed to ``torch.hub.load``.
        version: Entry point (model name) inside that repository.
        size: Target size used for both the resize and the center crop.
    """

    # Patch edge length used when the backbone does not expose ``patch_size``.
    # 14 is the value the mask pooling was originally hard-coded to.
    _DEFAULT_PATCH_SIZE = 14

    def __init__(
        self,
        model="facebookresearch/dinov2",
        version="dinov2_vitl14_reg",
        size=518,
    ):
        super().__init__()
        dino_model = torch.hub.load(model, version, pretrained=True)
        # Only the mode is switched to eval here; parameters are not frozen.
        self.encoder = dino_model.eval()

        resize = transforms.Resize(
            size,
            transforms.InterpolationMode.BILINEAR,
            antialias=True,
        )
        crop = transforms.CenterCrop(size)
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        self.transform = transforms.Compose([resize, crop, normalize])

    def forward(self, image, image_mask=None):
        """Encode ``image`` and optionally reduce ``image_mask`` to patches.

        Args:
            image: Input passed through ``self.transform`` and then the
                encoder.
            image_mask: Optional pixel mask. It is NOT passed through
                ``self.transform``; it is pooled as given.
                Assumes shape (batch, 1, H, W) - TODO confirm with callers.

        Returns:
            The normalized token tensor when ``image_mask`` is ``None``;
            otherwise a tuple ``(tokens, image_mask_patch)`` where
            ``image_mask_patch`` is a boolean tensor that is True for every
            patch containing at least one mask value greater than zero.
        """
        # ``is_training=True`` is forwarded to the backbone even though it is
        # in eval mode; presumably this selects the dict-returning code path
        # of the backbone - confirm against the backbone implementation.
        features = self.encoder(self.transform(image), is_training=True)
        tokens = features['x_prenorm']
        # Normalize over the last (feature) dimension without affine params.
        tokens = F.layer_norm(tokens, tokens.shape[-1:])

        if image_mask is None:
            return tokens

        # Pool with the backbone's own patch size instead of a literal 14 so
        # the patch mask follows whichever ``version`` was loaded.
        patch = getattr(self.encoder, "patch_size", self._DEFAULT_PATCH_SIZE)
        # NOTE(review): the mask skips the resize/crop applied to the image,
        # so it presumably must already match the transformed resolution for
        # the patch grid to line up with the tokens - confirm with callers.
        # NOTE(review): ``tokens`` may also hold non-patch tokens (class /
        # register), in which case its length exceeds the number of mask
        # patches - confirm how callers align the two.
        pooled = F.max_pool2d(image_mask, kernel_size=patch, stride=patch)
        image_mask_patch = pooled.squeeze(1) > 0
        return tokens, image_mask_patch