File size: 1,575 Bytes
2c8e31c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import torch.nn as nn
from transformers import SwinForImageClassification

def quantize_model(model, mode="linear"):
    """Put *model* in CPU eval mode and optionally quantize its linear layers.

    Args:
        model: The ``nn.Module`` to prepare. ``eval()`` and ``cpu()`` are
            called on it directly, so the passed-in module is affected.
        mode: ``"linear"`` applies dynamic int8 quantization to every
            ``nn.Linear``; any other value skips quantization.

    Returns:
        The result of ``torch.quantization.quantize_dynamic`` when ``mode`` is
        ``"linear"``, otherwise the same ``model`` object that was passed in.
    """
    model.eval()
    model.cpu()

    # Guard clause: anything other than "linear" means "leave as float".
    if mode != "linear":
        return model

    return torch.quantization.quantize_dynamic(
        model,
        qconfig_spec={nn.Linear},
        dtype=torch.qint8,
    )

class SwinModel(nn.Module):
    """Swin image classifier whose head is a fresh ``nn.Linear``.

    Wraps ``SwinForImageClassification`` and replaces its ``classifier`` with
    a new ``nn.Linear`` that has ``num_classes`` outputs.

    Args:
        model_name: Name or path handed to ``from_pretrained`` to look up the
            architecture (and the weights when ``from_pretrained`` is true).
        num_classes: Number of outputs of the replacement classifier.
        from_pretrained: When true, start from the pretrained weights of
            ``model_name``; when false, build the same architecture with
            freshly initialised weights.
    """

    def __init__(self, model_name="microsoft/swin-base-patch4-window7-224", num_classes=40, from_pretrained=False):
        super().__init__()

        if from_pretrained:
            self.model = SwinForImageClassification.from_pretrained(model_name)
        else:
            # Only the architecture description is needed on this path, so
            # load the config by itself instead of instantiating the whole
            # pretrained model just to read its `.config` and throw the
            # weights away.
            config = SwinForImageClassification.config_class.from_pretrained(model_name)
            config.num_labels = num_classes
            self.model = SwinForImageClassification(config)

        # Replace the head on both paths so it always has `num_classes`
        # outputs, regardless of what the loaded/constructed model came with.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, num_classes)

    def forward(self, images):
        """Run the wrapped model on *images* and return its ``logits``.

        ``images`` is passed as the first positional argument of the wrapped
        model (presumably a batch of pixel tensors -- confirm with callers).
        """
        outputs = self.model(images)
        return outputs.logits

def load_model(weights_path="best_model.pth", num_classes=40):
    """Build a ``SwinModel``, load checkpoint weights into it, and quantize it.

    Args:
        weights_path: Path handed to ``torch.load``. The loaded object is
            used as the state dict directly, unless it contains a
            ``"model_state_dict"`` entry, in which case that entry is used.
        num_classes: Number of outputs of the model's classification head.

    Returns:
        The model returned by ``quantize_model(model, mode="linear")``, in
        eval mode.

    Warns:
        UserWarning: When classifier tensors from the checkpoint are skipped,
            or when the checkpoint keys do not line up with the model's keys.
    """
    import warnings  # function-scope import keeps this block self-contained

    model = SwinModel(num_classes=num_classes, from_pretrained=False)

    # SECURITY NOTE(review): torch.load unpickles the file. On PyTorch
    # versions where `weights_only` is not the default, loading a checkpoint
    # from an untrusted source can execute arbitrary code. Consider passing
    # `weights_only=True` if checkpoints may come from outside the project.
    checkpoint = torch.load(weights_path, map_location="cpu")

    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Keep the trained classification head whenever it fits this model.
    # Dropping it unconditionally would leave the returned inference model
    # with a randomly initialised head. Head tensors that are absent from the
    # model or have a different shape are still skipped, so a checkpoint with
    # an incompatible head keeps loading without a size-mismatch error.
    expected = model.state_dict()
    loadable = {}
    skipped_head = []
    for key, value in state_dict.items():
        if "classifier" in key:
            target = expected.get(key)
            if target is None or target.shape != getattr(value, "shape", None):
                skipped_head.append(key)
                continue
        loadable[key] = value

    result = model.load_state_dict(loadable, strict=False)

    # strict=False hides key mismatches; surface them without failing so a
    # checkpoint saved under a different key layout does not silently yield
    # an untrained model.
    if skipped_head:
        warnings.warn(
            f"load_model: skipped incompatible classifier tensors {skipped_head}; "
            "the classification head keeps its random initialisation."
        )
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_model: checkpoint did not fully match the model "
            f"({len(result.missing_keys)} missing, "
            f"{len(result.unexpected_keys)} unexpected keys)."
        )

    model = quantize_model(model, mode="linear")
    model.eval()
    return model