# AML_16 / model.py
# Noha90's picture
# Clean start: add all files with LFS tracking
# 2c8e31c
# raw
# history blame
# 1.58 kB
import torch
import torch.nn as nn
from transformers import SwinForImageClassification
def quantize_model(model, mode="linear"):
    """Prepare *model* for CPU inference and optionally quantize it.

    The model is always switched to eval mode and moved to the CPU first;
    both operations modify the module that was passed in.

    Args:
        model: The ``nn.Module`` to prepare.
        mode: ``"linear"`` applies dynamic int8 quantization to the
            ``nn.Linear`` layers; any other value skips quantization.

    Returns:
        The result of ``torch.quantization.quantize_dynamic`` when
        ``mode == "linear"``, otherwise the same ``model`` object.
    """
    model.eval()
    model.cpu()

    # Guard clause: anything other than "linear" means "leave as float".
    if mode != "linear":
        return model

    return torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8,
    )
class SwinModel(nn.Module):
    """Swin Transformer image classifier that returns raw logits.

    Wraps a Hugging Face ``SwinForImageClassification`` and replaces its
    classification head with a new ``nn.Linear`` sized for ``num_classes``.
    """

    def __init__(
        self,
        model_name="microsoft/swin-base-patch4-window7-224",
        num_classes=40,
        from_pretrained=False,
    ):
        """Build the wrapped model and attach a ``num_classes``-way head.

        Args:
            model_name: Hugging Face checkpoint identifier (or local path)
                that supplies the architecture and, optionally, the weights.
            num_classes: Number of output classes for the new head.
            from_pretrained: When true, load the pretrained weights of
                ``model_name``. When false, only its configuration is read
                and the network is created with freshly initialised weights.
        """
        super().__init__()
        if from_pretrained:
            self.model = SwinForImageClassification.from_pretrained(model_name)
        else:
            # Only the architecture is needed in this branch, so read the
            # configuration on its own instead of loading (and discarding)
            # the full set of pretrained weights just to reach ``.config``.
            from transformers import SwinConfig

            config = SwinConfig.from_pretrained(model_name)
            config.num_labels = num_classes
            self.model = SwinForImageClassification(config)

        # Replace the head so its output size is ``num_classes`` regardless
        # of how many labels the loaded checkpoint was trained with.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, num_classes)

    def forward(self, images):
        """Run the wrapped model on *images* and return its ``logits``."""
        outputs = self.model(images)
        return outputs.logits
def load_model(weights_path="best_model.pth", num_classes=40):
    """Load a checkpoint into a ``SwinModel`` and return it quantized.

    Args:
        weights_path: Path to a file written with ``torch.save``. It may
            contain either a bare state dict or a dict that stores the
            state dict under the ``"model_state_dict"`` key.
        num_classes: Number of output classes of the model to build.

    Returns:
        The model after ``quantize_model(..., mode="linear")``, in eval mode.
    """
    model = SwinModel(num_classes=num_classes, from_pretrained=False)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source, or pass
    # ``weights_only=True`` if the checkpoint holds nothing but tensors.
    checkpoint = torch.load(weights_path, map_location="cpu")
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # ``strict=False`` tolerates missing/unexpected keys but still raises on
    # a shape mismatch, so incompatible classifier tensors must be dropped.
    # Compatible ones are kept: discarding them unconditionally would leave
    # the head randomly initialised and make predictions meaningless.
    expected = model.state_dict()
    loadable = {}
    for key, tensor in state_dict.items():
        if "classifier" in key:
            target = expected.get(key)
            if target is None or target.shape != getattr(tensor, "shape", None):
                continue
        loadable[key] = tensor
    model.load_state_dict(loadable, strict=False)

    model = quantize_model(model, mode="linear")
    model.eval()
    return model