|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
from transformers import SwinForImageClassification |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
import json |
|
import os |
|
import random |
|
from torchvision import transforms |
|
|
|
|
|
# Class labels, looked up by classifier output index in predict()
# (class_names[int(idx)]), so presumably a JSON list ordered like the model's
# outputs -- confirm against the training script.
# Read as UTF-8 explicitly so non-ASCII label text does not depend on the
# platform's default encoding.
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)

# Hugging Face model id of the pretrained Swin backbone.
MODEL_NAME = "microsoft/swin-large-patch4-window7-224"
|
|
|
class SwinCustom(nn.Module):
    """Swin image classifier with a two-layer MLP classification head.

    Wraps a Hugging Face ``SwinForImageClassification`` and replaces its
    classifier with ``Linear -> LeakyReLU -> Dropout(0.3) -> Linear``. The
    layout (and therefore the state-dict keys ``model.classifier.0.*`` and
    ``model.classifier.3.*``) must match the fine-tuned checkpoint loaded into
    this module.

    Args:
        model_name: Hugging Face model id or local path of the pretrained Swin
            weights.
        num_classes: Number of output classes.
    """

    def __init__(self, model_name=MODEL_NAME, num_classes=40):
        super().__init__()
        # ignore_mismatched_sizes tolerates the pretrained classifier having a
        # different number of labels than num_classes; that head is replaced
        # right below anyway.
        self.model = SwinForImageClassification.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        in_features = self.model.classifier.in_features

        activation = nn.LeakyReLU()
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, in_features),
            activation,
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes),
        )

        # For nonlinearity="leaky_relu", kaiming's ``a`` is the negative slope
        # of the rectifier; use the slope the head actually uses (0.01 by
        # default) rather than 0 so the computed gain matches.
        for layer in self.model.classifier:
            if isinstance(layer, nn.Linear):
                init.kaiming_uniform_(
                    layer.weight,
                    a=activation.negative_slope,
                    mode="fan_in",
                    nonlinearity="leaky_relu",
                )

    def forward(self, images):
        """Return raw class logits for a batch of images.

        Args:
            images: Pixel tensor forwarded unchanged to the wrapped Swin model
                (presumably normalised ``(batch, 3, 224, 224)`` -- confirm).

        Returns:
            The ``logits`` tensor of the wrapped model's output.
        """
        return self.model(images).logits
|
|
|
# Fetch the fine-tuned checkpoint from the Hugging Face Hub and build the
# model it belongs to.
model_path = hf_hub_download(
    repo_id="Noha90/AML_16", filename="large_swin_best_model.pth"
)
print("Model path:", model_path)

model = SwinCustom(model_name=MODEL_NAME, num_classes=40)

# NOTE(review): torch.load unpickles the file, and this one is downloaded from
# a remote repository -- only load checkpoints from a source you trust. If the
# checkpoint holds only tensors and plain containers, weights_only=True would
# avoid executing arbitrary pickled code.
state_dict = torch.load(model_path, map_location="cpu")

# The checkpoint may wrap the weights in a dict next to other training state;
# unwrap the known layouts, otherwise use the loaded object as the state dict.
if isinstance(state_dict, dict):
    for wrapper_key in ("model_state_dict", "state_dict"):
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
            break

# Weights saved from nn.DataParallel / DistributedDataParallel carry a
# "module." prefix on every key, which would otherwise match no parameter.
_DP_PREFIX = "module."
if (
    isinstance(state_dict, dict)
    and state_dict
    and all(str(key).startswith(_DP_PREFIX) for key in state_dict)
):
    state_dict = {key[len(_DP_PREFIX):]: value for key, value in state_dict.items()}

# strict=False tolerates mismatched keys, but ignoring them silently can leave
# layers (e.g. the custom classifier head) at their random initialisation and
# make every prediction meaningless -- so report whatever did not line up.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"[WARNING] {len(load_result.missing_keys)} model keys missing from "
        f"checkpoint, e.g. {load_result.missing_keys[:10]}"
    )
if load_result.unexpected_keys:
    print(
        f"[WARNING] {len(load_result.unexpected_keys)} checkpoint keys unused "
        f"by the model, e.g. {load_result.unexpected_keys[:10]}"
    )

# Inference mode: disables the head's dropout.
model.eval()
|
|
|
|
|
# Preprocessing applied to every image before inference: resize to exactly
# 224x224 (aspect ratio is not preserved), convert to a float tensor, then
# normalise each channel with the usual ImageNet mean/std.
# NOTE(review): this must match the training-time pipeline -- confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
# File extensions accepted when picking a reference sample image.
_REFERENCE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _load_rgb(path):
    """Open the image at *path*, return an RGB copy, and close the file."""
    # The context manager releases the file handle promptly; otherwise PIL may
    # keep it open (notably for multi-frame formats such as GIF/WebP).
    with Image.open(path) as img:
        return img.convert("RGB")


def _pick_reference_image(label):
    """Return a random RGB sample image for *label*, or None if none exists.

    Samples are looked up in ``sample_images/<label, spaces -> underscores>``;
    only regular files with a known image extension are considered.
    """
    class_folder = f"sample_images/{str(label).replace(' ', '_')}"
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None

    image_files = [
        name
        for name in os.listdir(class_folder)
        if name.lower().endswith(_REFERENCE_EXTENSIONS)
        and os.path.isfile(os.path.join(class_folder, name))
    ]
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None

    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    return _load_rgb(ref_path)


def predict(image_path):
    """Classify an image and fetch a reference sample of its top-1 class.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A tuple ``(image, reference_image, top5_probs)``: the input converted
        to an RGB PIL image; a random RGB sample image of the top-1 class, or
        ``None`` when no sample is available; and a dict mapping up to five
        class names to their softmax probabilities, in descending order.
    """
    image = _load_rgb(image_path)
    x = transform(image).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        outputs = model(x)
    print("Logits:", outputs)

    probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    print("Probs:", probs)
    print("Sum of probs:", probs.sum())

    # Clamp k so a model with fewer than five classes does not make topk fail.
    top5 = torch.topk(probs, k=min(5, probs.numel()))
    top1_label = class_names[int(top5.indices[0])]

    reference_image = _pick_reference_image(top1_label)

    top5_labels = [class_names[int(idx)] for idx in top5.indices]
    top5_probs = {
        label: float(score) for label, score in zip(top5_labels, top5.values)
    }

    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {top5_labels}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")

    return image, reference_image, top5_probs