import json
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# from model import load_model
from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")
TOP_K = 5

# Load labels. They are indexed by integer class id below (class_names[idx]),
# so presumably this is a JSON list ordered like the classifier outputs.
# NOTE(review): confirm the ordering matches the one used at training time.
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)


class DeiT(nn.Module):
    """DeiT (ViT) image classifier whose head is replaced by a new linear layer.

    Args:
        model_name: Hugging Face model id passed to
            ``ViTForImageClassification.from_pretrained``.
        num_classes: Number of outputs of the new classification head. Required.

    Raises:
        ValueError: If ``num_classes`` is not provided.
    """

    def __init__(self, model_name="facebook/deit-small-patch16-224", num_classes=None):
        super().__init__()
        # Fail fast, before downloading the backbone, instead of erroring inside nn.Linear.
        if num_classes is None:
            raise ValueError("num_classes must be provided")
        self.model = ViTForImageClassification.from_pretrained(model_name)
        in_features = self.model.classifier.in_features
        # The nn.Sequential wrapper determines the state-dict key names
        # (model.classifier.0.weight/bias), which presumably must match the
        # checkpoint loaded below; do not unwrap it.
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, num_classes)
        )

    def forward(self, images):
        """Return the raw logits tensor (not a Hugging Face output object)."""
        outputs = self.model(images)
        return outputs.logits


# Load model
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_best_model.pth")
print("Model path_check:", model_path)
model = DeiT(num_classes=len(class_names))
# NOTE(review): torch.load unpickles the file, and this checkpoint comes from the
# Hub. Consider passing weights_only=True (supported by recent torch versions) so
# that loading cannot execute arbitrary code.
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# DeiT transform. Must match the preprocessing used when the checkpoint was trained.
# NOTE(review): the stock facebook/deit preprocessing uses ImageNet mean/std
# (as in the Swin variant below); confirm 0.5/0.5 is what training used.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Swin
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                          std=[0.229, 0.224, 0.225])
# ])


def _load_rgb(path):
    """Open an image, return an RGB copy, and close the underlying file."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _pick_reference_image(label):
    """Return a random RGB sample image for *label* from sample_images/, or None."""
    class_folder = f"sample_images/{str(label).replace(' ', '_')}"
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None

    image_files = [f for f in os.listdir(class_folder) if f.lower().endswith(IMAGE_EXTENSIONS)]
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None

    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    return _load_rgb(ref_path)


def predict(image_path):
    """Classify an image and pick a random reference sample of the top-1 class.

    Args:
        image_path: Path of the image to classify (anything ``Image.open`` accepts).

    Returns:
        A tuple ``(image, reference_image, top5_probs)``:
        - ``image``: the input as an RGB PIL image;
        - ``reference_image``: a random RGB image from the predicted class's
          sample folder, or ``None`` if the folder is missing or empty;
        - ``top5_probs``: dict mapping class name to probability for the
          top-k classes (k = min(5, number of classes)), formatted for gr.Label.
    """
    # Load and prepare image
    image = _load_rgb(image_path)
    x = transform(image).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        # DeiT.forward already returns the logits tensor, so there is no `.logits` attribute to read.
        logits = model(x)
        print("Logits:", logits)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        print("Probs:", probs)
        print("Sum of probs:", probs.sum())

    # Clamp k so that models with fewer than TOP_K classes do not make topk fail.
    top5 = torch.topk(probs, k=min(TOP_K, probs.numel()))
    top1_idx = int(top5.indices[0])
    top1_label = class_names[top1_idx]

    # Select a random image from the class subfolder
    reference_image = _pick_reference_image(top1_label)

    # Format Top-5 for gr.Label with class names
    top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")
    return image, reference_image, top5_probs