File size: 1,194 Bytes
bcb05d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms


class DinoEncoder(nn.Module):
    """Wrap a DINOv2 backbone loaded via ``torch.hub`` as a token feature encoder.

    Images are resized (shorter side to ``size``), center-cropped to
    ``size`` x ``size`` and normalized with ImageNet statistics before being
    passed to the backbone. The backbone's ``'x_prenorm'`` output is then
    layer-normalized (no affine weight/bias) over its last dimension.

    Args:
        model: ``torch.hub`` repository to load the backbone from.
        version: Entry point inside ``model`` (e.g. ``"dinov2_vitl14_reg"``).
        size: Side length of the square crop fed to the backbone. Must be a
            multiple of the backbone's patch size.

    Raises:
        ValueError: If ``size`` is not a multiple of the backbone patch size.
    """

    def __init__(
            self,
            model="facebookresearch/dinov2",
            version="dinov2_vitl14_reg",
            size=518,
    ):
        super().__init__()

        # Loads pretrained weights through torch.hub (network / hub cache).
        dino_model = torch.hub.load(model, version, pretrained=True)
        # NOTE(review): eval() is applied only once here; calling .train() on
        # this wrapper switches the backbone back to training mode, and its
        # parameters are not frozen. Confirm both are intended.
        self.encoder = dino_model.eval()

        # DINOv2 backbones expose their patch size; fall back to 14 (the value
        # previously hard-coded) for backbones that do not.
        self.patch_size = getattr(dino_model, "patch_size", 14)
        if size % self.patch_size != 0:
            raise ValueError(
                f"size={size} must be a multiple of the patch size ({self.patch_size})"
            )

        self.transform = transforms.Compose(
            [
                transforms.Resize(size, transforms.InterpolationMode.BILINEAR, antialias=True),
                transforms.CenterCrop(size),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        # Same geometry as ``self.transform`` (resize shorter side, then
        # center-crop) so mask pixels stay aligned with image pixels and the
        # pooled mask lines up with the backbone's patch grid. Nearest
        # interpolation avoids blending mask values; no normalization.
        self.mask_transform = transforms.Compose(
            [
                transforms.Resize(size, transforms.InterpolationMode.NEAREST),
                transforms.CenterCrop(size),
            ]
        )

    def forward(self, image, image_mask=None):
        """Encode ``image`` and optionally reduce ``image_mask`` to the patch grid.

        Args:
            image: Tensor accepted by ``self.transform`` (``Normalize`` requires
                a float tensor). Presumably (B, 3, H, W) RGB scaled to [0, 1]
                given the ImageNet statistics — TODO confirm with callers.
            image_mask: Optional mask covering the same pixels as ``image``;
                assumed (B, 1, H, W) because of the ``squeeze(1)`` below —
                TODO confirm. Any dtype convertible to float (bool included).

        Returns:
            ``z``: the backbone's ``'x_prenorm'`` output, layer-normalized over
            the last dimension. When ``image_mask`` is given, returns
            ``(z, image_mask_patch)`` instead, where ``image_mask_patch`` is a
            bool tensor that is True for every patch_size x patch_size window
            containing at least one mask value > 0.
        """
        # In the DINOv2 hub implementation, is_training=True makes the
        # backbone return its feature dict rather than the head output
        # (confirm for other backbones).
        features = self.encoder(self.transform(image), is_training=True)
        z = features['x_prenorm']
        # Non-affine layer norm over the feature (last) dimension.
        z = F.layer_norm(z, z.shape[-1:])

        if image_mask is None:
            return z

        # Cast first: max_pool2d (and tensor resize) do not accept bool masks.
        # Then apply the same resize/crop as the image so the pooled grid
        # matches the tokens produced from the transformed image.
        mask = self.mask_transform(image_mask.float())
        # max over a window > 0  <=>  at least one mask value in the patch is > 0.
        image_mask_patch = F.max_pool2d(
            mask, kernel_size=self.patch_size, stride=self.patch_size
        ).squeeze(1) > 0
        return z, image_mask_patch