Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,194 Bytes
bcb05d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
class DinoEncoder(nn.Module):
    """Image encoder that wraps a ``torch.hub`` backbone (DINOv2 by default).

    The module preprocesses the input (resize, center crop, channel
    normalisation), runs the backbone, and returns its ``'x_prenorm'``
    features after a parameter-free layer norm over the last dimension.
    Optionally, a pixel-level mask is reduced to one boolean per patch.
    """

    def __init__(
        self,
        model="facebookresearch/dinov2",
        version="dinov2_vitl14_reg",
        size=518,
    ):
        """Load the backbone and build the preprocessing pipeline.

        Args:
            model: ``torch.hub`` repository to load the backbone from.
            version: Entry-point name inside that repository.
            size: Target size used for both the resize and the center crop.
        """
        super().__init__()
        backbone = torch.hub.load(model, version, pretrained=True)
        # Put the backbone in eval mode at construction time. Note that a later
        # ``.train()`` call on this module would switch it back.
        self.encoder = backbone.eval()
        # Side length of one patch, used to pool pixel masks down to the patch
        # grid. Read from the backbone when it exposes ``patch_size``; 14 is
        # the value that was previously hard-coded here.
        self.patch_size = getattr(backbone, "patch_size", 14)
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    size,
                    transforms.InterpolationMode.BILINEAR,
                    antialias=True,
                ),
                transforms.CenterCrop(size),
                # ImageNet channel statistics.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

    def forward(self, image, image_mask=None):
        """Encode ``image`` and optionally pool ``image_mask`` to patch level.

        Args:
            image: Image tensor accepted by ``self.transform``
                (presumably ``(B, 3, H, W)`` with values in [0, 1] -- confirm
                with callers).
            image_mask: Optional mask tensor suitable for ``F.max_pool2d``
                (presumably ``(B, 1, H, W)``). Bool and integer masks are
                accepted and cast to float before pooling.
                NOTE(review): the mask is pooled as given and is *not* passed
                through ``self.transform``, so it presumably has to match the
                resized/cropped image already -- confirm with callers.

        Returns:
            The layer-normalised features when ``image_mask`` is ``None``;
            otherwise a tuple ``(features, patch_mask)`` where ``patch_mask``
            is a boolean tensor that is ``True`` for every patch containing at
            least one positive mask value.
        """
        # ``is_training=True`` is what makes the encoder return a dict that
        # contains 'x_prenorm' (behaviour of the DINOv2 forward -- confirm if
        # a different backbone is plugged in).
        features = self.encoder(self.transform(image), is_training=True)["x_prenorm"]
        # Normalise over the feature dimension without learnable affine params.
        features = F.layer_norm(features, features.shape[-1:])
        if image_mask is None:
            return features

        # max_pool2d has no kernel for bool/integer inputs; cast those to
        # float. Floating masks are left untouched so results are unchanged.
        if not image_mask.is_floating_point():
            image_mask = image_mask.float()
        pooled = F.max_pool2d(
            image_mask,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        patch_mask = pooled.squeeze(1) > 0
        return features, patch_mask
|