"""Swin Transformer image classifier: model wrapper, checkpoint loading, CPU quantization."""

import warnings

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import SwinForImageClassification


def quantize_model(model, mode="linear"):
    """Prepare *model* for CPU inference and optionally quantize it.

    *model* is switched to eval mode and moved to the CPU in place.

    Args:
        model: the ``nn.Module`` to prepare.
        mode: ``"linear"`` applies dynamic int8 quantization to every
            ``nn.Linear`` submodule; any other value returns the model
            unquantized.

    Returns:
        The dynamically quantized model for ``mode="linear"``, otherwise
        *model* itself.
    """
    model.eval().cpu()
    if mode == "linear":
        return torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    return model


class SwinModel(nn.Module):
    """Hugging Face ``SwinForImageClassification`` with a ``num_classes``-way head.

    ``forward`` returns the raw logits tensor instead of the Hugging Face
    output object.
    """

    def __init__(self, model_name="microsoft/swin-base-patch4-window7-224",
                 num_classes=40, from_pretrained=False):
        """Build the Swin backbone and attach a fresh linear classifier.

        Args:
            model_name: Hugging Face Hub id (or local path) whose configuration
                (and, if ``from_pretrained``, weights) define the backbone.
            num_classes: number of output logits.
            from_pretrained: if True, start from the pretrained weights of
                ``model_name``; otherwise only its configuration is fetched
                and all weights are freshly initialised.
        """
        super().__init__()
        if from_pretrained:
            self.model = SwinForImageClassification.from_pretrained(model_name)
        else:
            # Only the architecture is needed: fetching the config alone avoids
            # downloading and instantiating the full pretrained weights just to
            # read `.config` and discard them.
            config = SwinForImageClassification.config_class.from_pretrained(model_name)
            config.num_labels = num_classes
            self.model = SwinForImageClassification(config)
        # Always install a num_classes-wide head; a pretrained checkpoint's
        # head is sized for that checkpoint's own label set.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(in_features, num_classes)

    def forward(self, images):
        """Return classification logits for a batch of images.

        Args:
            images: pixel tensor passed straight to the wrapped Swin model
                (presumably ``(batch, 3, H, W)`` preprocessed pixels — TODO
                confirm against the data pipeline).

        Returns:
            The ``logits`` tensor produced by the wrapped model.
        """
        return self.model(images).logits


def load_model(weights_path=None, num_classes=40):
    """Build a ``SwinModel``, load trained weights, and quantize it for CPU.

    Args:
        weights_path: path to a ``torch.save`` checkpoint holding either a raw
            state dict or a dict with the state dict under
            ``"model_state_dict"``. When None, ``best_model.pth`` is downloaded
            from the ``Noha90/AML_16`` Hugging Face Hub repo.
        num_classes: width of the classifier head. It must match the
            checkpoint for the trained head to be used; otherwise the head
            stays randomly initialised and a warning is emitted.

    Returns:
        The dynamically quantized model in eval mode on the CPU.
    """
    if weights_path is None:
        # Download from Hugging Face Hub if not provided
        weights_path = hf_hub_download(repo_id="Noha90/AML_16", filename="best_model.pth")

    model = SwinModel(num_classes=num_classes, from_pretrained=False)

    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from trusted sources (consider weights_only=True if the checkpoint
    # contains only tensors and plain containers).
    checkpoint = torch.load(weights_path, map_location="cpu")
    state_dict = checkpoint.get("model_state_dict", checkpoint)

    # Keep the trained classifier head. Drop only head tensors whose shape
    # does not fit the freshly built head (e.g. a checkpoint trained with a
    # different class count); those would make load_state_dict raise.
    own_state = model.state_dict()
    mismatched_head = {
        key for key, tensor in state_dict.items()
        if "classifier" in key and key in own_state
        and tensor.shape != own_state[key].shape
    }
    if mismatched_head:
        warnings.warn(
            f"Checkpoint classifier does not match num_classes={num_classes}; "
            f"keeping a randomly initialised head for {sorted(mismatched_head)}"
        )
    state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_head}

    # strict=False tolerates extra checkpoint entries, but report parameters
    # that were never loaded instead of silently keeping their initial values.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        warnings.warn(
            f"{len(result.missing_keys)} model entries were not found in the "
            f"checkpoint and keep their initial values, e.g. {result.missing_keys[:5]}"
        )

    return quantize_model(model, mode="linear").eval()