|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import json |
|
import numpy as np |
|
|
|
from transformers import AutoImageProcessor, SwinForImageClassification |
|
import torch.nn as nn |
|
import os |
|
import pandas as pd |
|
import random |
|
|
|
|
|
# Load the label set used to size the classifier head and to translate
# predicted indices into names. JSON is UTF-8 by specification, so the encoding
# is pinned explicitly instead of relying on the platform default (which would
# mis-decode non-ASCII labels on e.g. Windows).
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)

print("class_names:", class_names)
|
|
|
|
|
|
|
# Build the network: start from the published Swin checkpoint, then swap its
# classification head for a linear layer with one output per entry in
# ``class_names``.
model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224")
model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load("best_model.pth", map_location="cpu")

# Keep the checkpoint's classifier weights: without them the freshly created
# head above stays randomly initialised and every prediction is meaningless.
# Only tensors whose shape disagrees with the model are dropped, because
# load_state_dict raises on size mismatches even with strict=False.
expected_shapes = {name: tensor.shape for name, tensor in model.state_dict().items()}
compatible_state_dict = {}
skipped_keys = []
for key, value in state_dict.items():
    shape_conflict = (
        key in expected_shapes
        and torch.is_tensor(value)
        and value.shape != expected_shapes[key]
    )
    if shape_conflict:
        skipped_keys.append(key)
    else:
        compatible_state_dict[key] = value

if skipped_keys:
    print(f"[WARN] Skipped checkpoint tensors with mismatched shapes: {skipped_keys}")

# strict=False keeps the original tolerance for extra or absent keys, but the
# result is inspected so that untrained parameters do not go unnoticed.
load_result = model.load_state_dict(compatible_state_dict, strict=False)
if load_result.missing_keys:
    print(f"[WARN] Parameters not found in checkpoint: {load_result.missing_keys}")

# Inference mode: disables dropout and other training-only behaviour.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing applied to every input image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalise each channel.
# The mean/std values match the commonly used ImageNet channel statistics.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]

transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
# File extensions (lower-case) accepted as reference images.
_REFERENCE_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _load_reference_image(label):
    """Return a randomly chosen RGB sample image for ``label``, or ``None``.

    Looks in ``sample_images/<label with spaces replaced by underscores>`` and
    picks one file with a recognised image extension. A missing folder or a
    folder without images is reported on stdout and yields ``None``.
    """
    class_folder = f"sample_images/{str(label).replace(' ', '_')}"

    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None

    image_files = [
        name
        for name in os.listdir(class_folder)
        if name.lower().endswith(_REFERENCE_IMAGE_EXTENSIONS)
    ]
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None

    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")

    # convert() returns an independent copy, so the file can be closed right away.
    with Image.open(ref_path) as ref_file:
        return ref_file.convert("RGB")


def predict(image_path):
    """Classify the image at ``image_path`` with the module-level ``model``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A tuple ``(image, reference_image, top5_probs)`` where ``image`` is the
        input converted to RGB, ``reference_image`` is a random RGB sample of
        the top-1 class (``None`` when no sample is available), and
        ``top5_probs`` maps class label to softmax probability for the
        highest-scoring classes (at most five).
    """
    # Use a context manager so the underlying file handle is released promptly
    # instead of waiting for garbage collection.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Add the batch dimension expected by the model.
    x = transform(image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(x)
        print("Logits:", outputs.logits)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
        print("Probs:", probs)
        print("Sum of probs:", probs.sum())
        # topk raises if k exceeds the number of classes, so clamp it for
        # label sets with fewer than five entries.
        top5 = torch.topk(probs, k=min(5, probs.numel()))

    # topk returns results sorted by descending score, so index 0 is the top-1.
    top1_idx = int(top5.indices[0])
    top1_label = class_names[top1_idx]

    reference_image = _load_reference_image(top1_label)

    top5_probs = {
        class_names[int(idx)]: float(score)
        for idx, score in zip(top5.indices, top5.values)
    }

    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")

    return image, reference_image, top5_probs
|
|
|
|
|
|
|
|