"""Inference helpers for a fine-tuned Swin image classifier.

Importing this module loads the class labels from ``labels.json`` and the
fine-tuned weights from ``best_model.pth``, then exposes :func:`predict`.
"""

import json
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import AutoImageProcessor, SwinForImageClassification

# Load labels
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)

# Load model: start from the pretrained backbone and swap in a head sized for
# our label set.
model = SwinForImageClassification.from_pretrained(
    "microsoft/swin-base-patch4-window7-224"
)
model.classifier = torch.nn.Linear(model.classifier.in_features, len(class_names))

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load("best_model.pth", map_location="cpu")

# Keep only the checkpoint entries the model can actually accept (known key and
# identical shape). Filtering by shape instead of by name means the fine-tuned
# classifier head is loaded whenever it matches len(class_names); dropping it
# unconditionally would leave the head randomly initialised.
_model_state = model.state_dict()
filtered_state_dict = {
    key: tensor
    for key, tensor in state_dict.items()
    if key in _model_state
    and getattr(tensor, "shape", None) == _model_state[key].shape
}
_skipped_keys = sorted(str(key) for key in set(state_dict) - set(filtered_state_dict))
if _skipped_keys:
    print(
        "[DEBUG] Checkpoint keys not loaded (unknown or shape mismatch): "
        f"{_skipped_keys}"
    )
model.load_state_dict(filtered_state_dict, strict=False)
model.eval()

# Image transform (Swin): 224x224 input with the mean/std values below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _pick_reference_image(label):
    """Return a random RGB image from the sample folder of *label*, or None.

    The folder is ``sample_images/<label with spaces replaced by underscores>``.
    None is returned when the folder is missing or contains no image files.
    """
    class_folder = f"sample_images/{str(label).replace(' ', '_')}"
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None

    # List all image files in the folder
    image_files = [
        name
        for name in os.listdir(class_folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None

    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    # convert() returns a fully loaded copy, so the file can be closed here.
    with Image.open(ref_path) as ref:
        return ref.convert("RGB")


def predict(image_path):
    """Classify the image at *image_path*.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts for the input image.

    Returns:
        A tuple ``(image, reference_image, top5_probs)``:
        the input as an RGB PIL image, a random sample image of the top-1
        class (or None if none is available), and a dict mapping the top
        class names (at most five) to their softmax probabilities.
    """
    # Load and prepare image
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    x = transform(image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(x)
        print("Logits:", outputs.logits)
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
        print("Probs:", probs)
        print("Sum of probs:", probs.sum())

    # topk raises if k exceeds the number of classes, so clamp it.
    top5 = torch.topk(probs, k=min(5, probs.numel()))
    top1_idx = int(top5.indices[0])
    top1_label = class_names[top1_idx]

    # Select a random image from the class subfolder
    reference_image = _pick_reference_image(top1_label)

    # Format Top-5 for gr.Label with class names
    top5_probs = {
        class_names[int(idx)]: float(score)
        for idx, score in zip(top5.indices, top5.values)
    }

    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")

    return image, reference_image, top5_probs