|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import json |
|
import numpy as np |
|
|
|
from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification |
|
import torch.nn as nn |
|
import os |
|
import pandas as pd |
|
import random |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Class labels. predict() indexes this with integer class positions taken from
# the model's output, and the model head is sized with len(class_names).
# NOTE(review): labels.json is presumably a JSON array ordered like the label
# encoding used during training — confirm.
# JSON text is UTF-8; read it explicitly as such instead of relying on the
# platform's default encoding (which breaks non-ASCII labels on some systems).
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)

print("class_names:", class_names)
|
|
|
class DeiT(nn.Module):
    """DeiT image classifier with a freshly initialised classification head.

    Loads a pretrained ``ViTForImageClassification`` checkpoint and replaces
    its ``classifier`` with a single linear layer producing ``num_classes``
    outputs. The head is wrapped in ``nn.Sequential``, so its parameters are
    named ``model.classifier.0.*``; keep this layout so state dicts saved from
    this class still load.

    Args:
        model_name: Hugging Face model id or local path for ``from_pretrained``.
        num_classes: Number of output classes. Required despite the ``None``
            default (kept for backward compatibility of the signature).

    Raises:
        ValueError: If ``num_classes`` is missing or smaller than 1.
    """

    def __init__(self, model_name="facebook/deit-small-patch16-224", num_classes=None):
        super().__init__()
        # Validate before the potentially slow pretrained-weights download; the
        # None default would otherwise fail obscurely inside nn.Linear.
        if num_classes is None or num_classes < 1:
            raise ValueError(f"num_classes must be a positive integer, got {num_classes!r}")

        self.model = ViTForImageClassification.from_pretrained(model_name)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, num_classes)
        )

    def forward(self, images):
        """Return classification logits for a batch of images.

        Args:
            images: Pixel tensor passed straight to the ViT model — presumably
                shaped (batch, 3, H, W) and normalised like the training data;
                confirm against the caller.

        Returns:
            The ``logits`` tensor of the underlying model's output (a plain
            Tensor, not a ``ModelOutput``).
        """
        outputs = self.model(images)
        return outputs.logits
|
|
|
|
|
# Fetch the fine-tuned checkpoint from the Hugging Face Hub and build the
# classifier with one output per entry in class_names.
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_best_model.pth")
print("Model path:", model_path)

model = DeiT(num_classes=len(class_names))

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted repository (consider weights_only=True if
# the installed torch version supports it).
state_dict = torch.load(model_path, map_location="cpu")
# The checkpoint may be a training snapshot dict that nests the weights under
# "model_state_dict"; unwrap it if so.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

# strict=False tolerates key mismatches, but silently skipped keys would leave
# those parameters (e.g. the new head) untrained — surface them.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print("[WARNING] Keys missing from checkpoint:", load_result.missing_keys)
if load_result.unexpected_keys:
    print("[WARNING] Unexpected keys in checkpoint:", load_result.unexpected_keys)

model.eval()
|
|
|
|
|
# Preprocessing pipeline applied to every input image before inference:
# resize to a fixed square, convert to a float tensor, then normalise channels.
# NOTE(review): mean = std = 0.5 maps pixel values from [0, 1] to [-1, 1]; this
# must match the normalisation used when deit_best_model.pth was trained — the
# pretrained DeiT processor config may use different statistics, so verify
# against the training script.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

_preprocessing_steps = [
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# File extensions accepted as reference images (compared case-insensitively).
_REFERENCE_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")


def _load_rgb(path):
    """Open the image at *path*, convert it to RGB and close the file handle."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image that remains valid after
        # the source file is closed.
        return img.convert("RGB")


def _pick_reference_image(label):
    """Return a random RGB image from ``sample_images/<label>``, or None.

    Spaces in *label* become underscores to form the folder name. Prints a
    debug message and returns None when the folder is missing or holds no
    files with a recognised image extension.
    """
    class_folder = f"sample_images/{str(label).replace(' ', '_')}"
    if not os.path.isdir(class_folder):
        print(f"[DEBUG] Class folder does not exist: {class_folder}")
        return None

    # Sorted because os.listdir order is arbitrary; this makes the choice
    # reproducible when `random` is seeded. Directories are skipped even if
    # their names look like image files.
    image_files = sorted(
        name
        for name in os.listdir(class_folder)
        if name.lower().endswith(_REFERENCE_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(class_folder, name))
    )
    if not image_files:
        print(f"[DEBUG] No images found in {class_folder}")
        return None

    ref_path = os.path.join(class_folder, random.choice(image_files))
    print(f"[DEBUG] Randomly selected reference image: {ref_path}")
    return _load_rgb(ref_path)


def predict(image_path):
    """Classify the image at *image_path* with the module-level model.

    Args:
        image_path: Path to the input image file.

    Returns:
        A tuple ``(image, reference_image, top5_probs)``:
        ``image`` is the input as an RGB PIL image; ``reference_image`` is a
        randomly chosen RGB sample from ``sample_images/<top-1 label>`` or
        None if none is available; ``top5_probs`` maps up to five class names
        to their softmax probabilities, highest first.
    """
    image = _load_rgb(image_path)
    x = transform(image).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        outputs = model(x)
    # DeiT.forward already returns the logits tensor (which has no `.logits`
    # attribute); fall back to `.logits` for models returning a ModelOutput.
    logits = getattr(outputs, "logits", outputs)
    print("Logits:", logits)

    probs = torch.nn.functional.softmax(logits, dim=1)[0]
    print("Probs:", probs)
    print("Sum of probs:", probs.sum())

    # torch.topk raises if k exceeds the number of classes.
    top5 = torch.topk(probs, k=min(5, probs.numel()))
    top1_label = class_names[int(top5.indices[0])]

    reference_image = _pick_reference_image(top1_label)

    top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")

    return image, reference_image, top5_probs
|
|
|
|
|
|
|
|