File size: 3,582 Bytes
2c8e31c 7220b97 2c8e31c 7220b97 2c8e31c 7220b97 07132cf 52b50f7 7220b97 52b50f7 7220b97 2c8e31c 07132cf 7220b97 52b50f7 7220b97 fb7d78b 2061c9f 2c8e31c 07132cf 7220b97 07132cf 7220b97 07132cf 2c8e31c 07132cf 2c8e31c 07132cf 2c8e31c 07132cf 2c8e31c 07132cf 2c8e31c 07132cf 2c8e31c 7220b97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import torch
import torch.nn as nn
import torch.nn.init as init
from transformers import SwinForImageClassification
from huggingface_hub import hf_hub_download
from PIL import Image
import json
import os
import random
from torchvision import transforms
# Load the list of class names. predict() indexes it with model output
# indices, so it is presumably ordered like the classifier outputs (a JSON
# list) — TODO confirm against how labels.json was produced.
# JSON text is UTF-8 by specification; decode explicitly instead of relying on
# the platform's locale encoding, which mangles non-ASCII class names.
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)
# Hugging Face Hub id of the Swin backbone wrapped by SwinCustom.
MODEL_NAME = "microsoft/swin-large-patch4-window7-224"  # or your preferred Swin model
class SwinCustom(nn.Module):
    """Swin Transformer classifier with a custom two-layer MLP head.

    The backbone is loaded with ``SwinForImageClassification.from_pretrained``
    and its ``classifier`` attribute is replaced by
    Linear -> LeakyReLU -> Dropout(0.3) -> Linear, ending in ``num_classes``
    outputs.

    Args:
        model_name: Hub id or local path of the pretrained Swin checkpoint.
        num_classes: Number of output logits.
    """

    def __init__(self, model_name=MODEL_NAME, num_classes=40):
        super().__init__()
        # num_labels + ignore_mismatched_sizes let from_pretrained load a
        # checkpoint whose head width differs; that head is replaced below.
        self.model = SwinForImageClassification.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        hidden = self.model.classifier.in_features
        head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        )
        # Re-initialise the Linear weights only; biases are left as constructed.
        # NOTE(review): a=0 gives the ReLU gain, while nn.LeakyReLU() uses a
        # negative slope of 0.01 — confirm this mismatch is intended.
        linear_layers = [layer for layer in head if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            init.kaiming_uniform_(
                layer.weight, a=0, mode="fan_in", nonlinearity="leaky_relu"
            )
        self.model.classifier = head

    def forward(self, images):
        """Run the wrapped model on ``images`` and return its logits tensor."""
        return self.model(images).logits
# Download the fine-tuned checkpoint from the Hub and load it into a fresh
# SwinCustom instance for inference.
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
print("Model path:", model_path)
model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted repo (consider weights_only=True on torch versions that support it).
state_dict = torch.load(model_path, map_location="cpu")
# Some checkpoints wrap the weights as {"model_state_dict": ..., ...}; unwrap.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]
# Checkpoints saved from nn.DataParallel / DistributedDataParallel prefix every
# key with "module."; SwinCustom's keys never start with it, so strip it.
_prefix = "module."
state_dict = {
    (key[len(_prefix):] if key.startswith(_prefix) else key): value
    for key, value in state_dict.items()
}
# strict=False tolerates mismatches, but any key that fails to line up leaves
# that layer at its random initialisation — report it instead of staying silent.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys:
    print(
        f"[WARNING] {len(_load_result.missing_keys)} model keys missing from "
        f"checkpoint: {_load_result.missing_keys}"
    )
if _load_result.unexpected_keys:
    print(
        f"[WARNING] {len(_load_result.unexpected_keys)} unexpected checkpoint "
        f"keys ignored: {_load_result.unexpected_keys}"
    )
model.eval()
# Per-channel normalisation statistics (the common ImageNet mean/std values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input before inference: fixed 224x224 resize,
# conversion to a tensor, then per-channel normalisation.
# NOTE(review): this must match the transforms used when the checkpoint was
# fine-tuned — confirm against the training script.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# File extensions accepted as reference images (matched case-insensitively).
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _load_rgb(path):
    """Open the image at ``path`` and return an RGB copy, closing the file."""
    # Image.open reads lazily and keeps the file handle open; the context
    # manager closes it once convert() has produced an in-memory image.
    with Image.open(path) as img:
        return img.convert("RGB")


def _pick_reference_image(label):
    """Return a random RGB image from ``label``'s sample folder, or None."""
    class_folder = f"sample_images/{str(label).replace(' ', '_')}"
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None
    # os.listdir order is arbitrary; sort so the pick is reproducible under
    # random.seed(). Skip directories that merely have an image-like name.
    image_files = sorted(
        name
        for name in os.listdir(class_folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(class_folder, name))
    )
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None
    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    return _load_rgb(ref_path)


def predict(image_path):
    """Classify one image and fetch a sample image of the predicted class.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        ``(image, reference_image, top5_probs)`` where ``image`` is the input
        converted to RGB, ``reference_image`` is a random RGB image from
        ``sample_images/<top-1 label, spaces -> underscores>`` or ``None`` if
        none is available, and ``top5_probs`` maps each of the (up to) five
        most likely labels to its softmax probability.
    """
    image = _load_rgb(image_path)
    x = transform(image).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        outputs = model(x)
    print("Logits:", outputs)
    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    print("Probs:", probs)
    print("Sum of probs:", probs.sum())
    # topk raises if k exceeds the number of classes; cap it.
    top5 = torch.topk(probs, k=min(5, probs.numel()))
    top5_labels = [class_names[int(idx)] for idx in top5.indices]
    top1_label = top5_labels[0]
    reference_image = _pick_reference_image(top1_label)
    top5_probs = {
        label: float(score) for label, score in zip(top5_labels, top5.values)
    }
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {top5_labels}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")
    return image, reference_image, top5_probs