"""Inference script for a fine-tuned Swin-Large image classifier.

Loads class names from ``labels.json``, downloads the fine-tuned checkpoint
from the Hugging Face Hub, and exposes :func:`predict`, which returns the
input image as RGB, a random reference image for the top-1 class (if one
exists under ``sample_images/``), and the top-5 class probabilities.
"""

import json
import os
import random

import torch
import torch.nn as nn
import torch.nn.init as init
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import SwinForImageClassification

# Load labels
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)

MODEL_NAME = "microsoft/swin-large-patch4-window7-224"
NUM_CLASSES = 40
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")

# A label list that disagrees with the head size would make predict() index
# class_names out of range or attach the wrong names to outputs.
if len(class_names) != NUM_CLASSES:
    print(
        f"[WARNING] labels.json has {len(class_names)} entries "
        f"but the model is built with {NUM_CLASSES} outputs"
    )


class SwinCustom(nn.Module):
    """Swin backbone with a two-layer MLP classification head.

    Args:
        model_name: Hugging Face model id of the pretrained Swin checkpoint.
        num_classes: Number of output classes of the new head.
    """

    def __init__(self, model_name=MODEL_NAME, num_classes=NUM_CLASSES):
        super().__init__()
        # ignore_mismatched_sizes lets from_pretrained accept a num_labels that
        # differs from the pretrained head; the head is replaced below anyway.
        self.model = SwinForImageClassification.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes),
        )
        # Weight initialization; only matters for head parameters that the
        # loaded checkpoint does not overwrite.
        for m in self.model.classifier:
            if isinstance(m, nn.Linear):
                init.kaiming_uniform_(
                    m.weight, a=0, mode="fan_in", nonlinearity="leaky_relu"
                )

    def forward(self, images):
        """Return raw class logits for a batch of preprocessed images."""
        return self.model(images).logits


model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
print("Model path:", model_path)

model = SwinCustom(model_name=MODEL_NAME, num_classes=NUM_CLASSES)
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code; only load checkpoints from trusted repos. Consider
# weights_only=True if the checkpoint holds only tensors/primitive containers.
state_dict = torch.load(model_path, map_location="cpu")
if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]
# strict=False skips mismatched keys silently; report them so a checkpoint that
# leaves part of the model (e.g. the head) randomly initialized is visible.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
    print(f"[WARNING] Missing keys in checkpoint: {incompatible.missing_keys}")
if incompatible.unexpected_keys:
    print(f"[WARNING] Unexpected keys in checkpoint: {incompatible.unexpected_keys}")
model.eval()

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _open_rgb(path):
    """Open an image file as RGB and close the underlying file handle."""
    # convert() loads the pixel data into a new image, so the file can be
    # closed as soon as the with-block exits.
    with Image.open(path) as img:
        return img.convert("RGB")


def _pick_reference_image(label):
    """Return a random RGB sample image for ``label``, or None if none is found.

    Samples are looked up in ``sample_images/<label>`` with spaces in the label
    replaced by underscores.
    """
    class_folder = os.path.join("sample_images", str(label).replace(" ", "_"))
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None

    image_files = [
        name for name in os.listdir(class_folder)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None

    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    return _open_rgb(ref_path)


def predict(image_path):
    """Classify the image at ``image_path``.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A tuple ``(image, reference_image, top5_probs)``:
        the input as an RGB PIL image; a random RGB sample image of the
        top-1 class, or ``None`` if no sample is available; and a dict mapping
        up to five class names to their softmax probabilities, highest first.
    """
    image = _open_rgb(image_path)
    x = transform(image).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        outputs = model(x)
        print("Logits:", outputs)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
        print("Probs:", probs)
        print("Sum of probs:", probs.sum())

    # Clamp k so a head with fewer than five classes does not crash topk.
    top5 = torch.topk(probs, k=min(5, probs.numel()))
    top5_labels = [class_names[int(idx)] for idx in top5.indices]
    top1_label = top5_labels[0]

    # Select a random image from the class subfolder
    reference_image = _pick_reference_image(top1_label)

    top5_probs = dict(zip(top5_labels, top5.values.tolist()))

    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {top5_labels}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")
    return image, reference_image, top5_probs