# jeffacce
# initial commit
# 393d3de
import torch
import torchvision
import torch.nn as nn
class resnet18(nn.Module):
    """ResNet-18 image encoder that returns flattened backbone features.

    Wraps ``torchvision.models.resnet18`` with its last child module removed,
    so the network emits the flattened output of the remaining layers instead
    of class logits. Inputs are channel-normalized before the backbone, and
    the features can optionally be L2-normalized.

    Args:
        pretrained: If True, load torchvision's pretrained ResNet-18 weights;
            otherwise the backbone is randomly initialized.
        output_dim: Not used by the model. The feature width is fixed (512)
            for ResNet-18; the argument is accepted only for consistency
            with the configuration schema.
        unit_norm: If True, scale every output feature vector to unit L2 norm.
    """

    def __init__(
        self,
        pretrained: bool = True,
        output_dim: int = 512,  # fixed for resnet18; included for consistency with config
        unit_norm: bool = False,
    ):
        super().__init__()
        # torchvision >= 0.13 deprecates `pretrained=` in favor of `weights=`.
        # Use the weights enum when it exists (the legacy flag maps to
        # IMAGENET1K_V1) and fall back to the old keyword on older releases.
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is None:
            resnet = torchvision.models.resnet18(pretrained=pretrained)
        else:
            resnet = torchvision.models.resnet18(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        # Keep every child module except the last one (the classifier head).
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.flatten = nn.Flatten()
        self.pretrained = pretrained
        # Per-channel mean/std conventionally used with torchvision's
        # ImageNet-pretrained models.
        # NOTE(review): presumably inputs are float RGB scaled to [0, 1] -- confirm.
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.unit_norm = unit_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image tensor into feature vectors.

        The last three dimensions of ``x`` are treated as one image. A 3-D
        input is handled as a single unbatched image, a 4-D input as a batch,
        and any extra leading dimensions are folded into the batch for the
        backbone pass and restored on the output.

        Args:
            x: Tensor whose trailing three dimensions form an image.

        Returns:
            Tensor of shape ``(*x.shape[:-3], feature_dim)``; for a 3-D input
            the result is a single ``(feature_dim,)`` vector.
        """
        orig_shape = x.shape
        dims = len(orig_shape)
        if dims == 3:
            # Single image: add a temporary batch dimension.
            x = x.unsqueeze(0)
        elif dims > 4:
            # Flatten all leading dimensions to batch, reshape back at the end.
            x = x.reshape(-1, *orig_shape[-3:])

        x = self.normalize(x)
        out = self.flatten(self.resnet(x))
        if self.unit_norm:
            out = torch.nn.functional.normalize(out, p=2, dim=-1)

        if dims == 3:
            out = out.squeeze(0)
        elif dims > 4:
            # Pass the feature width explicitly: reshape(..., -1) is ambiguous
            # (and raises) when a leading dimension is 0 and the tensor is empty.
            out = out.reshape(*orig_shape[:-3], out.shape[-1])
        return out