File size: 3,582 Bytes
2c8e31c
7220b97
 
 
 
2c8e31c
 
 
 
7220b97
2c8e31c
 
 
 
 
 
7220b97
 
 
 
 
 
07132cf
52b50f7
7220b97
 
 
52b50f7
 
7220b97
 
 
 
2c8e31c
07132cf
 
 
 
7220b97
 
52b50f7
7220b97
fb7d78b
2061c9f
 
 
2c8e31c
07132cf
7220b97
07132cf
 
 
7220b97
07132cf
 
2c8e31c
 
07132cf
2c8e31c
07132cf
 
 
2c8e31c
 
 
 
07132cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c8e31c
07132cf
2c8e31c
 
 
07132cf
2c8e31c
 
7220b97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import torch.nn as nn
import torch.nn.init as init
from transformers import SwinForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import json
import os
import random
from torchvision import transforms

# Load the class labels used to map prediction indices to names.
# JSON text is UTF-8 by specification, so pin the encoding instead of relying
# on the platform's locale default (which can mis-decode non-ASCII labels).
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)

MODEL_NAME = "microsoft/swin-large-patch4-window7-224"  # or your preferred Swin model

class SwinCustom(nn.Module):
    """Swin image classifier with a replacement classification head.

    Wraps ``SwinForImageClassification`` and swaps its ``classifier`` for a
    Linear -> LeakyReLU -> Dropout(0.3) -> Linear stack that ends in
    ``num_classes`` outputs.
    """

    def __init__(self, model_name=MODEL_NAME, num_classes=40):
        super().__init__()
        backbone = SwinForImageClassification.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        hidden_dim = backbone.classifier.in_features
        head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )
        # Only the Linear weights are re-initialised here; biases are left
        # exactly as nn.Linear created them.
        linear_layers = [layer for layer in head if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        backbone.classifier = head
        self.model = backbone

    def forward(self, images):
        """Run the wrapped model on ``images`` and return its ``logits``."""
        return self.model(images).logits

# Download the fine-tuned model checkpoint from the Hub
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
print("Model path:", model_path)
model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
# NOTE(review): torch.load unpickles the downloaded file. If the checkpoint is
# confirmed to hold only tensors/primitive containers, consider passing
# weights_only=True to avoid executing arbitrary pickled code.
state_dict = torch.load(model_path, map_location="cpu")
# Some checkpoints wrap the weights under a "model_state_dict" entry.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]
# strict=False tolerates key mismatches, so surface them explicitly: otherwise
# a mismatched checkpoint would silently leave layers at their initial values.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys:
    print("[WARN] Checkpoint is missing keys:", _load_result.missing_keys)
if _load_result.unexpected_keys:
    print("[WARN] Checkpoint has unexpected keys:", _load_result.unexpected_keys)
model.eval()

# Inference preprocessing: fixed 224x224 resize, tensor conversion, then
# per-channel normalization (the original note labels this "Swin default").
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

def _pick_reference_image(label):
    """Return a random RGB sample image for ``label``, or ``None``.

    Looks in ``sample_images/<label>`` (spaces in the label replaced by
    underscores) and picks one file with a recognised image extension.
    ``None`` is returned when the folder is absent or holds no such file.
    """
    class_folder = f"sample_images/{str(label).replace(' ', '_')}"
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None

    image_files = [
        f for f in os.listdir(class_folder)
        if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))
    ]
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None

    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    # The context manager releases the file handle; convert() returns an
    # independent in-memory image, so it stays usable after the file closes.
    with Image.open(ref_path) as ref_file:
        return ref_file.convert("RGB")


def predict(image_path):
    """Classify an image and return it with a reference sample and top scores.

    Args:
        image_path: Path (or other source accepted by ``Image.open``) of the
            image to classify.

    Returns:
        A tuple ``(image, reference_image, top5_probs)`` where ``image`` is
        the input converted to RGB, ``reference_image`` is a random RGB sample
        of the top-1 class (or ``None`` if none is available), and
        ``top5_probs`` maps the highest-scoring class labels to their softmax
        probabilities as floats.
    """
    # Close the source file once the RGB copy has been made.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    x = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(x)
        print("Logits:", outputs)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
        print("Probs:", probs)
        print("Sum of probs:", probs.sum())
        # topk raises if k exceeds the number of classes, so clamp it.
        top5 = torch.topk(probs, k=min(5, probs.shape[0]))

    top1_idx = int(top5.indices[0])
    top1_label = class_names[top1_idx]

    # Select a random image from the predicted class's subfolder
    reference_image = _pick_reference_image(top1_label)

    top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")

    return image, reference_image, top5_probs