File size: 3,479 Bytes
2c8e31c 7220b97 925a7e9 7220b97 2c8e31c 7220b97 2c8e31c 925a7e9 07132cf 52b50f7 925a7e9 52b50f7 925a7e9 2c8e31c 07132cf 925a7e9 07132cf cacd37b 52b50f7 925a7e9 fb7d78b 2061c9f 2c8e31c 07132cf 925a7e9 07132cf 2c8e31c 07132cf 2c8e31c 07132cf db6ba08 925a7e9 2c8e31c 07132cf 2c8e31c 07132cf 2c8e31c 07132cf 2c8e31c 7220b97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import torch
import torch.nn as nn
import torch.nn.init as init
from transformers import SwinForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import json
import os
import random
from torchvision import transforms
# Load the human-readable class names. predict() indexes this with the model's
# output index, so entry i must name output class i.
# JSON is UTF-8 by specification; pass the encoding explicitly so non-ASCII
# labels decode correctly regardless of the platform's locale default.
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)

# Hugging Face Hub id of the pretrained Swin checkpoint SwinCustom loads by default.
MODEL_NAME = "microsoft/swin-large-patch4-window7-224"
class SwinCustom(nn.Module):
    """Swin image classifier whose linear head is replaced by a two-layer MLP.

    Wraps ``SwinForImageClassification`` and swaps its classifier for
    Linear -> LeakyReLU -> Dropout -> Linear, sized to ``num_classes``.

    Args:
        model_name: Hub id or local path of the pretrained Swin checkpoint.
        num_classes: number of output classes (logit width).
        negative_slope: slope of the LeakyReLU in the head. It is also passed
            to Kaiming init as ``a`` so the init gain matches the activation
            actually used. The default (0.01) equals ``nn.LeakyReLU()``'s default.
        dropout: dropout probability between the two head layers.
    """

    def __init__(self, model_name=MODEL_NAME, num_classes=40, negative_slope=0.01, dropout=0.3):
        super().__init__()
        # num_labels + ignore_mismatched_sizes let from_pretrained resize the
        # checkpoint's head to num_classes instead of raising on the shape
        # mismatch. That head is replaced just below; only its input width is reused.
        self.model = SwinForImageClassification.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(dropout),
            nn.Linear(in_features, num_classes),
        )
        # Kaiming-uniform init for the head's weight matrices. `a` must be the
        # LeakyReLU slope for the 'leaky_relu' gain to be correct. Biases keep
        # PyTorch's default nn.Linear initialisation.
        for layer in self.model.classifier:
            if isinstance(layer, nn.Linear):
                init.kaiming_uniform_(
                    layer.weight, a=negative_slope, mode="fan_in", nonlinearity="leaky_relu"
                )

    def forward(self, images):
        """Return classification logits of shape (batch, num_classes).

        ``images`` is passed straight to the HF model. This presumably expects a
        normalised (batch, 3, H, W) pixel tensor, which is what predict() builds.
        """
        return self.model(images).logits
# Fetch the fine-tuned checkpoint from the Hugging Face Hub and build the
# model it will be loaded into.
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="swin_large_quantised.pth")
print("Model path:", model_path)
model = SwinCustom(model_name=MODEL_NAME, num_classes=40)

# NOTE(review): torch.load unpickles the file, and unpickling can execute
# arbitrary code. Only load checkpoints from a repository you trust.
state_dict = torch.load(model_path, map_location="cpu")
# Training checkpoints may wrap the weights under "model_state_dict" alongside
# other metadata; unwrap to the bare state dict.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

# strict=False tolerates key mismatches, but any parameter it skips keeps its
# fresh (random/pretrained) value. Report mismatches instead of hiding them.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
    print(
        f"[WARN] {len(incompatible.missing_keys)} model parameters not found in checkpoint: "
        f"{incompatible.missing_keys}"
    )
if incompatible.unexpected_keys:
    print(
        f"[WARN] {len(incompatible.unexpected_keys)} checkpoint entries ignored: "
        f"{incompatible.unexpected_keys}"
    )
model.eval()

# Preprocessing: resize to a fixed 224x224 and normalise with the standard
# ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# File extensions accepted as sample (reference) images.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _pick_reference_image(label, root="sample_images"):
    """Return a random RGB sample image for *label*, or None if none is available.

    Looks in ``<root>/<label with spaces replaced by underscores>`` and picks
    uniformly among regular files with an extension in ``_IMAGE_EXTENSIONS``.
    """
    class_folder = os.path.join(root, str(label).replace(" ", "_"))
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None
    image_files = [
        name
        for name in os.listdir(class_folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(class_folder, name))
    ]
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None
    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    # Context manager closes the underlying file; convert() returns an
    # independent, fully loaded image.
    with Image.open(ref_path) as ref:
        return ref.convert("RGB")


def predict(image_path):
    """Classify the image at *image_path* with the module-level ``model``.

    Returns:
        A tuple ``(image, reference_image, top5_probs)``:
        - ``image``: the input converted to RGB.
        - ``reference_image``: a random RGB sample from the predicted class's
          folder under ``sample_images``, or None if none exists.
        - ``top5_probs``: dict mapping class name to softmax probability for
          the (up to) five most likely classes.
    """
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    x = transform(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        outputs = model(x)
    print("Logits:", outputs)
    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    print("Probs:", probs)
    print("Sum of probs:", probs.sum())
    # Clamp k so label sets with fewer than five classes do not make topk raise.
    top5 = torch.topk(probs, k=min(5, probs.numel()))
    top1_label = class_names[int(top5.indices[0])]

    reference_image = _pick_reference_image(top1_label)

    top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")
    return image, reference_image, top5_probs