|
import torch |
|
import torch.nn as nn |
|
from transformers import SwinForImageClassification |
|
|
|
def quantize_model(model, mode="linear"):
    """Prepare *model* for CPU inference and optionally quantize its Linear layers.

    The model is switched to eval mode and moved to the CPU in place.  When
    ``mode`` is ``"linear"``, every ``nn.Linear`` submodule is dynamically
    quantized to int8 and the quantized module returned by
    ``quantize_dynamic`` is returned; for any other ``mode`` the (float)
    model object itself is returned unchanged apart from eval/CPU.
    """
    model.eval().cpu()

    # Any mode other than "linear" is a passthrough: no quantization applied.
    if mode != "linear":
        return model

    return torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8,
    )
|
|
|
class SwinModel(nn.Module):
    """Swin Transformer image classifier with a ``num_classes``-way linear head.

    Wraps a Hugging Face ``SwinForImageClassification`` and exposes a
    ``forward`` that returns the raw logits.

    Args:
        model_name: Hugging Face hub id (or local path) of the Swin checkpoint
            whose architecture -- and, when ``from_pretrained`` is true, whose
            weights -- are used.
        num_classes: Number of outputs of the classification head.
        from_pretrained: If true, load the checkpoint's weights; otherwise
            build the same architecture with randomly initialized weights.
    """

    def __init__(
        self,
        model_name="microsoft/swin-base-patch4-window7-224",
        num_classes=40,
        from_pretrained=False,
    ):
        super().__init__()

        if from_pretrained:
            self.model = SwinForImageClassification.from_pretrained(model_name)
        else:
            # Only the architecture is needed, so fetch just the config rather
            # than downloading and instantiating the whole pretrained model
            # only to throw its weights away.
            config = SwinForImageClassification.config_class.from_pretrained(model_name)
            config.num_labels = num_classes
            self.model = SwinForImageClassification(config)

        # Install a fresh num_classes-way linear head (in the pretrained branch
        # this replaces whatever head the checkpoint shipped with).
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, num_classes)

        # Keep the wrapped model's label metadata consistent with the new head;
        # otherwise config.num_labels / id2label still describe the old head
        # and any label-based loss inside the HF model uses the wrong size.
        self.model.config.num_labels = num_classes
        self.model.num_labels = num_classes

    def forward(self, images):
        """Return classification logits for a batch of images.

        ``images`` is passed positionally to the wrapped HF model (its
        ``pixel_values`` argument); presumably a float tensor of shape
        (batch, 3, H, W) -- TODO confirm against callers.
        """
        outputs = self.model(images)
        return outputs.logits
|
|
|
def load_model(weights_path="best_model.pth", num_classes=40):
    """Build a ``SwinModel``, load checkpoint weights into it, and quantize it.

    Args:
        weights_path: Path to a file written by ``torch.save``.  It may hold
            either a bare state dict or a dict whose ``"model_state_dict"``
            entry is the state dict.
        num_classes: Size of the classification head.  Checkpoint tensors
            whose shape does not match the freshly built model (e.g. a head
            trained for a different number of classes) are skipped, and a
            warning is emitted.

    Returns:
        The model in eval mode on the CPU, with its ``nn.Linear`` layers
        dynamically quantized to int8.
    """
    import warnings

    model = SwinModel(num_classes=num_classes, from_pretrained=False)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    checkpoint = torch.load(weights_path, map_location="cpu")

    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Checkpoints saved through nn.DataParallel / DistributedDataParallel
    # prefix every key with "module."; strip it so the keys line up.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # Keep every tensor whose shape matches the model -- including the trained
    # classifier head -- and skip only genuine shape mismatches.  (Previously
    # all "classifier" keys were dropped unconditionally, leaving the head
    # randomly initialized.)
    model_state = model.state_dict()
    compatible, mismatched = {}, []
    for key, value in state_dict.items():
        expected = model_state.get(key)
        if (
            expected is not None
            and isinstance(value, torch.Tensor)
            and value.shape != expected.shape
        ):
            mismatched.append(key)
        else:
            compatible[key] = value

    result = model.load_state_dict(compatible, strict=False)
    if mismatched or result.missing_keys or result.unexpected_keys:
        # strict=False keeps loading best-effort, but silently missing weights
        # (e.g. an untrained classifier head) would make predictions useless.
        warnings.warn(
            f"load_model({weights_path!r}): skipped {len(mismatched)} "
            f"shape-mismatched key(s) {mismatched[:5]}; "
            f"{len(result.missing_keys)} missing key(s) "
            f"{list(result.missing_keys)[:5]}; "
            f"{len(result.unexpected_keys)} unexpected key(s) "
            f"{list(result.unexpected_keys)[:5]}",
            stacklevel=2,
        )

    model = quantize_model(model, mode="linear")
    model.eval()
    return model
|
|