import torch
import torchvision
import torch.nn as nn


class resnet18(nn.Module):
    """ResNet-18 backbone used as an image feature extractor.

    The final child module of torchvision's ResNet-18 (the classification
    layer) is dropped, so the network ends at its global pooling stage and
    the flattened output is one feature vector per image.

    Inputs are normalised with fixed per-channel mean/std before being fed to
    the backbone (these are the channel statistics torchvision documents for
    its ImageNet-pretrained models).
    NOTE(review): this presumably expects float RGB images already scaled to
    [0, 1] -- confirm against the callers.

    Args:
        pretrained: Load ImageNet-pretrained weights when True, otherwise use
            randomly initialised weights.
        output_dim: Unused; fixed for resnet18 and only accepted for
            consistency with config.
        unit_norm: When True, L2-normalise each output feature vector.
    """

    def __init__(
        self,
        pretrained: bool = True,
        output_dim: int = 512,  # fixed for resnet18; included for consistency with config
        unit_norm: bool = False,
    ):
        super().__init__()

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``.  Prefer the new API when it exists and fall back to
        # the legacy keyword so older installations keep working.
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is not None:
            # IMAGENET1K_V1 is the checkpoint ``pretrained=True`` mapped to.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            backbone = torchvision.models.resnet18(weights=weights)
        else:
            backbone = torchvision.models.resnet18(pretrained=pretrained)

        # Keep every child except the last one (the classification layer).
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.flatten = nn.Flatten()
        self.pretrained = pretrained
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.unit_norm = unit_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images into feature vectors.

        Args:
            x: Tensor whose last three dimensions are one image. It may be a
                single image (3 dims), a batch (4 dims), or carry any number
                of extra leading dimensions (more than 4 dims).

        Returns:
            Features with the same leading dimensions as ``x`` and the last
            three dimensions replaced by a single feature dimension.

        Raises:
            ValueError: If ``x`` has fewer than 3 dimensions.
        """
        dims = x.dim()
        orig_shape = x.shape
        if dims < 3:
            raise ValueError(
                f"Expected an input with at least 3 dimensions, got shape {tuple(orig_shape)}"
            )

        if dims == 3:
            # Single image: add a temporary batch dimension.
            x = x.unsqueeze(0)
        elif dims > 4:
            # Merge all leading dimensions into one batch dimension; they are
            # restored at the end.  ``flatten`` (unlike ``reshape(-1, ...)``)
            # is unambiguous even when the tensor has zero elements.
            x = x.flatten(0, -4)

        x = self.normalize(x)
        out = self.flatten(self.resnet(x))
        if self.unit_norm:
            out = torch.nn.functional.normalize(out, p=2, dim=-1)

        if dims == 3:
            out = out.squeeze(0)
        elif dims > 4:
            # Restore the leading dimensions; the feature size is given
            # explicitly so empty leading dimensions do not make it ambiguous.
            out = out.reshape(*orig_shape[:-3], out.shape[-1])
        return out