|
import warnings

import torch
import torch.nn as nn
import torchvision
|
|
|
|
|
class resnet18(nn.Module):
    """ResNet-18 image encoder returning pooled, optionally L2-normalized, features.

    The torchvision ResNet-18 is truncated before its final ``fc`` classifier,
    so ``forward`` returns the flattened global-average-pool output. Inputs are
    normalized with fixed per-channel mean/std (the ImageNet statistics used by
    torchvision's pretrained weights) before entering the backbone.

    Args:
        pretrained: If True, load torchvision's default ImageNet weights;
            otherwise the backbone is randomly initialized.
        output_dim: Expected feature dimension. The backbone's feature size is
            fixed and no projection layer is added, so a value that differs
            from it only triggers a warning. The real size is exposed as
            ``self.output_dim``.
        unit_norm: If True, L2-normalize every output feature vector.
    """

    def __init__(
        self,
        pretrained: bool = True,
        output_dim: int = 512,
        unit_norm: bool = False,
    ):
        super().__init__()

        models = torchvision.models
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of the
        # ``weights=`` enum (DEFAULT == the weights pretrained=True loaded);
        # fall back to the legacy keyword on releases without the enum.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.DEFAULT if pretrained else None
            resnet = models.resnet18(weights=weights)
        else:
            resnet = models.resnet18(pretrained=pretrained)

        # Width of the pooled features that fed the (discarded) fc head.
        feature_dim = resnet.fc.in_features
        if output_dim != feature_dim:
            # Previously ignored silently; surface the mismatch without
            # breaking existing callers.
            warnings.warn(
                f"resnet18 ignores output_dim={output_dim}; features are "
                f"{feature_dim}-dimensional (no projection layer is added).",
                stacklevel=2,
            )

        # Keep every child except the final fc classifier.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.flatten = nn.Flatten()
        self.pretrained = pretrained
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.unit_norm = unit_norm
        # Actual size of the last axis returned by ``forward``.
        self.output_dim = feature_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images into feature vectors.

        Args:
            x: Image tensor of shape ``(C, H, W)``, ``(N, C, H, W)`` or
                ``(*batch, C, H, W)`` with several leading batch dims.
                Presumably float RGB scaled to [0, 1], matching the fixed
                normalization stats — confirm with callers.

        Returns:
            Features with the leading dims of ``x`` preserved:
            ``(D,)``, ``(N, D)`` or ``(*batch, D)``.
        """
        orig_shape = x.shape
        dims = x.dim()

        # Bring every supported input rank to a single (N, C, H, W) batch.
        if dims == 3:
            x = x.unsqueeze(0)
        elif dims > 4:
            x = x.reshape(-1, *orig_shape[-3:])

        x = self.normalize(x)
        out = self.flatten(self.resnet(x))

        if self.unit_norm:
            out = nn.functional.normalize(out, p=2, dim=-1)

        # Restore the caller's leading dimensions.
        if dims == 3:
            out = out.squeeze(0)
        elif dims > 4:
            out = out.reshape(*orig_shape[:-3], -1)
        return out
|
|