import torch
import torch.nn as nn
from transformers import AutoImageProcessor, AutoModelForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import json
import os
import random
from torchvision import transforms

# Load labels
with open("labels.json", "r") as f:
    class_names = json.load(f)
print("class_names:", class_names)

class ViT(nn.Module):
    """Vision Transformer classifier: wraps a pretrained ViT backbone and swaps
    its classification head for dropout followed by a linear layer sized for
    num_classes outputs."""

    def __init__(self, model_name="google/vit-base-patch16-224", num_classes=40, dropout_rate=0.1):
        super(ViT, self).__init__()
        self.extractor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features, num_classes)
        )
        self.img_size = (self.extractor.size['height'], self.extractor.size['width'])
        self.normalize = transforms.Normalize(mean=self.extractor.image_mean, std=self.extractor.image_std)

    def forward(self, images):
        outputs = self.model(pixel_values=images)
        return outputs.logits

    def get_test_transforms(self):
        return transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            self.normalize
        ])

# Download your fine-tuned model checkpoint from the Hub
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="vit_best_model.pth")
print("Model path:", model_path)
model = ViT(model_name="google/vit-base-patch16-224", num_classes=40)
state_dict = torch.load(model_path, map_location="cpu")
# Training checkpoints may wrap the weights under a "model_state_dict" key
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]
# strict=False tolerates key mismatches between the checkpoint and this wrapper
model.load_state_dict(state_dict, strict=False)
model.eval()

transform = model.get_test_transforms()

def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    x = transform(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(x)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
        print("Logits:", outputs)
        print("Probs:", probs)
        print("Sum of probs:", probs.sum())
        top5 = torch.topk(probs, k=5)

    top1_idx = int(top5.indices[0])
    top1_label = class_names[top1_idx]

    # Select a random image from the class subfolder
    class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
    reference_image = None
    if os.path.isdir(class_folder):
        image_files = [f for f in os.listdir(class_folder) if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
        if image_files:
            chosen_file = random.choice(image_files)
            ref_path = os.path.join(class_folder, chosen_file)
            print(f"[DEBUG] Randomly selected reference image: {ref_path}")
            reference_image = Image.open(ref_path).convert("RGB")
        else:
            print(f"[DEBUG] No images found in {class_folder}")
    else:
        print(f"[DEBUG] Class folder does not exist: {class_folder}")

    top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")

    return image, reference_image, top5_probs
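
# Example usage (a minimal sketch): "query.jpg" is a placeholder path assumed
# for illustration, not a file that ships with this script; point it at any
# RGB image on disk.
if __name__ == "__main__":
    query_image, reference_image, top5_scores = predict("query.jpg")
    for label, prob in top5_scores.items():
        print(f"{label}: {prob:.4f}")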