File size: 3,479 Bytes
2c8e31c
7220b97
 
925a7e9
7220b97
2c8e31c
 
 
 
7220b97
2c8e31c
 
 
 
 
 
925a7e9
 
 
 
 
 
07132cf
52b50f7
925a7e9
 
 
52b50f7
 
925a7e9
 
 
 
2c8e31c
07132cf
925a7e9
07132cf
 
cacd37b
52b50f7
925a7e9
fb7d78b
2061c9f
 
 
2c8e31c
07132cf
925a7e9
 
 
 
 
 
07132cf
2c8e31c
 
07132cf
2c8e31c
07132cf
db6ba08
925a7e9
2c8e31c
 
 
 
07132cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c8e31c
07132cf
2c8e31c
 
 
07132cf
2c8e31c
 
7220b97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn as nn
import torch.nn.init as init
from transformers import SwinForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import json
import os
import random
from torchvision import transforms

# Load the label list used to decode model outputs: predict() indexes it with
# integer class indices. JSON is UTF-8 by spec, so pass the encoding explicitly;
# the default is locale-dependent (e.g. cp1252 on Windows) and would mangle
# non-ASCII labels.
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)

# Hugging Face Hub id of the pretrained Swin checkpoint the classifier is built on.
MODEL_NAME = "microsoft/swin-large-patch4-window7-224"

class SwinCustom(nn.Module):
    """Swin image classifier whose linear head is replaced by a small MLP.

    The Hugging Face ``SwinForImageClassification`` model is loaded from
    ``model_name``. Its ``classifier`` is then swapped for
    Linear -> LeakyReLU -> Dropout(0.3) -> Linear, ending in ``num_classes``
    outputs. ``forward`` returns the wrapped model's ``logits``.
    """

    def __init__(self, model_name=MODEL_NAME, num_classes=40):
        super().__init__()
        # ignore_mismatched_sizes allows building a head sized for num_classes
        # when the checkpoint's own head may differ; it is replaced below anyway.
        backbone = SwinForImageClassification.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        hidden = backbone.classifier.in_features
        head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        )
        # Re-initialise only the linear weights; biases keep nn.Linear's default init.
        for layer in (m for m in head if isinstance(m, nn.Linear)):
            init.kaiming_uniform_(layer.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")
        backbone.classifier = head
        self.model = backbone

    def forward(self, images):
        """Run the wrapped Swin model on ``images`` and return its logits."""
        return self.model(images).logits

# Download the fine-tuned checkpoint from the Hugging Face Hub and load it into
# a freshly built SwinCustom with 40 output classes.
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="swin_large_quantised.pth")
print("Model path:", model_path)
model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted repo. If this checkpoint holds plain
# tensors, consider torch.load(..., weights_only=True).
state_dict = torch.load(model_path, map_location="cpu")
# Training checkpoints may nest the weights under "model_state_dict".
if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]
# strict=False tolerates key mismatches, but any unmatched parameter silently
# keeps whatever SwinCustom.__init__ gave it (pretrained backbone / freshly
# initialised head). Report mismatches so a checkpoint that does not fit the
# architecture is visible instead of quietly producing meaningless logits.
load_result = model.load_state_dict(state_dict, strict=False)
for key_kind, keys in (
    ("missing", load_result.missing_keys),
    ("unexpected", load_result.unexpected_keys),
):
    if keys:
        print(f"[WARNING] {len(keys)} {key_kind} key(s) when loading checkpoint, e.g. {keys[:5]}")
model.eval()

# Preprocessing pipeline:
# 1. resize to a fixed 224x224;
# 2. convert the PIL image to a float tensor;
# 3. normalise each RGB channel with the standard ImageNet mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# File extensions accepted as reference sample images.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _reference_image_for(label, root="sample_images"):
    """Return a random RGB image from ``<root>/<label>``, or None if unavailable.

    Spaces in ``label`` are replaced by underscores to form the folder name.
    """
    class_folder = f"{root}/{str(label).replace(' ', '_')}"
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None
    # Sort so a seeded `random` picks the same file regardless of the
    # filesystem's (arbitrary) listing order.
    image_files = sorted(
        name for name in os.listdir(class_folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None
    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    # Close the file handle deterministically. Multi-frame formats (GIF/WebP)
    # otherwise keep it open after convert().
    with Image.open(ref_path) as ref:
        return ref.convert("RGB")


def predict(image_path):
    """Classify the image at ``image_path`` and fetch a sample of the top class.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        A tuple ``(image, reference_image, top5_probs)``:
        - ``image``: the input converted to RGB.
        - ``reference_image``: a random RGB image from
          ``sample_images/<top-1 label>``, or ``None`` when that folder is
          missing or has no images.
        - ``top5_probs``: a dict mapping each of the (up to) 5 most likely
          labels to its softmax probability.
    """
    # convert() returns an independent in-memory copy, so the file can be closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    x = transform(image).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        outputs = model(x)
        print("Logits:", outputs)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
        print("Probs:", probs)
        print("Sum of probs:", probs.sum())
        # Clamp k so a label set smaller than 5 does not make topk raise.
        top5 = torch.topk(probs, k=min(5, probs.numel()))

    top_labels = [class_names[int(idx)] for idx in top5.indices]
    top1_label = top_labels[0]

    reference_image = _reference_image_for(top1_label)

    top5_probs = {label: float(score) for label, score in zip(top_labels, top5.values)}
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {top_labels}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")

    return image, reference_image, top5_probs