File size: 1,780 Bytes
2c8e31c
 
 
937214e
2c8e31c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937214e
 
 
 
2c8e31c
937214e
2c8e31c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import torch.nn as nn
from transformers import SwinForImageClassification
from huggingface_hub import hf_hub_download

def quantize_model(model, mode="linear"):
    """Put ``model`` in CPU eval mode and optionally quantize it.

    The model is switched to eval mode and moved to the CPU before any
    quantization is attempted, regardless of ``mode``.

    Args:
        model: The ``nn.Module`` to prepare.
        mode: ``"linear"`` applies dynamic int8 quantization to the
            ``nn.Linear`` layers; any other value skips quantization.

    Returns:
        The output of ``torch.quantization.quantize_dynamic`` when ``mode``
        is ``"linear"``, otherwise the ``model`` object that was passed in.
    """
    model.eval()
    model.cpu()

    # Guard clause: anything other than "linear" means "leave as is".
    if mode != "linear":
        return model

    layer_types = {nn.Linear}
    return torch.quantization.quantize_dynamic(model, layer_types, dtype=torch.qint8)

class SwinModel(nn.Module):
    """Swin Transformer image classifier with a freshly created head.

    Wraps ``SwinForImageClassification`` and always replaces its
    ``classifier`` with a new ``nn.Linear`` producing ``num_classes`` logits.
    """

    def __init__(self, model_name="microsoft/swin-base-patch4-window7-224", num_classes=40, from_pretrained=False):
        """Build the backbone and attach a new classification head.

        Args:
            model_name: Hugging Face model id (or local path) that supplies
                the architecture config and, optionally, pretrained weights.
            num_classes: Number of output logits of the classification head.
            from_pretrained: When true, load the pretrained weights for
                ``model_name``; otherwise build the architecture from its
                config with newly initialised weights.
        """
        super().__init__()

        if from_pretrained:
            self.model = SwinForImageClassification.from_pretrained(model_name)
        else:
            # Only the architecture is needed on this path, so fetch just the
            # config rather than downloading and instantiating the complete
            # pretrained model merely to read its ``.config``.
            from transformers import SwinConfig

            config = SwinConfig.from_pretrained(model_name)
            config.num_labels = num_classes
            self.model = SwinForImageClassification(config)

        # Replace the head on both paths so the output size is always
        # ``num_classes``, whatever head the backbone was created with.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, num_classes)

    def forward(self, images):
        """Run the wrapped model on ``images`` and return its ``logits``."""
        outputs = self.model(images)
        return outputs.logits

def load_model(weights_path=None, num_classes=40):
    """Build a ``SwinModel``, load checkpoint weights and quantize it.

    Args:
        weights_path: Path to a checkpoint readable by ``torch.load``. The
            file may contain either a bare state dict or a dict holding one
            under ``"model_state_dict"``. When ``None``, ``best_model.pth``
            is downloaded from the ``Noha90/AML_16`` Hugging Face Hub repo.
        num_classes: Output size of the classification head to build.

    Returns:
        The model after ``quantize_model(..., mode="linear")``, in eval mode.
    """
    import warnings

    if weights_path is None:
        # Download from Hugging Face Hub if not provided
        weights_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")

    model = SwinModel(num_classes=num_classes, from_pretrained=False)

    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # from a trusted source (consider ``weights_only=True`` if the
    # checkpoint contains nothing but tensors and plain containers).
    checkpoint = torch.load(weights_path, map_location="cpu")
    state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint

    # Keep the trained classification head whenever it fits this model.
    # Discarding it unconditionally would leave the returned inference model
    # with a randomly initialised head. Only classifier tensors that are
    # absent from the model or have a different shape are dropped.
    expected = model.state_dict()
    loadable = {}
    dropped = []
    for key, tensor in state_dict.items():
        if "classifier" in key:
            target = expected.get(key)
            if target is None or getattr(tensor, "shape", None) != target.shape:
                dropped.append(key)
                continue
        loadable[key] = tensor

    if dropped:
        warnings.warn(
            f"Skipped {len(dropped)} classifier tensor(s) that do not fit the model: {dropped[:5]}"
        )

    # strict=False tolerates key differences, but report them so that a
    # checkpoint that does not match the model is not loaded silently.
    result = model.load_state_dict(loadable, strict=False)
    if result.missing_keys:
        warnings.warn(
            f"{len(result.missing_keys)} model parameter(s) were not found in the checkpoint, "
            f"e.g. {result.missing_keys[:5]}"
        )
    if result.unexpected_keys:
        warnings.warn(
            f"{len(result.unexpected_keys)} checkpoint entr(ies) were not used by the model, "
            f"e.g. {result.unexpected_keys[:5]}"
        )

    model = quantize_model(model, mode="linear")
    model.eval()
    return model