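"""Inference script for a fine-tuned ViT image classifier.

Loads class labels from labels.json, rebuilds the classification head,
downloads the fine-tuned checkpoint (Noha90/AML_16) from the Hugging Face
Hub, and returns the top-5 predictions plus a random reference image for
the predicted class.
"""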
import json
import os
import random

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Load class labels (index -> name) saved alongside the model at training time
with open("labels.json", "r") as f:
    class_names = json.load(f)
print("class_names:", class_names)
class ViT(nn.Module):
    """ViT backbone with a fresh dropout + linear classification head."""

    def __init__(self, model_name="google/vit-base-patch16-224", num_classes=40, dropout_rate=0.1):
        super().__init__()
        self.extractor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        # Replace the 1000-class ImageNet head with one sized for this task
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(in_features, num_classes),
        )
        # Mirror the processor's expected input size and normalization stats
        self.img_size = (self.extractor.size["height"], self.extractor.size["width"])
        self.normalize = transforms.Normalize(mean=self.extractor.image_mean, std=self.extractor.image_std)

    def forward(self, images):
        outputs = self.model(pixel_values=images)
        return outputs.logits

    def get_test_transforms(self):
        # Deterministic eval-time pipeline matching the processor settings
        return transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            self.normalize,
        ])
# Download the fine-tuned model checkpoint from the Hub
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="vit_best_model.pth")
print("Model path:", model_path)

model = ViT(model_name="google/vit-base-patch16-224", num_classes=40)
state_dict = torch.load(model_path, map_location="cpu")
# Some checkpoints wrap the weights under a "model_state_dict" key
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]
model.load_state_dict(state_dict, strict=False)
model.eval()

transform = model.get_test_transforms()
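# Optional sanity check (an added sketch, not part of the original script):
# with the default ViT processor, the eval transform should yield a
# (3, 224, 224) float tensor.
_dummy = transform(Image.new("RGB", (640, 480)))
assert _dummy.shape == (3, 224, 224), f"unexpected input shape {_dummy.shape}"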
def predict(image_path):
    """Classify an image; return (input image, reference image, top-5 probs)."""
    image = Image.open(image_path).convert("RGB")
    x = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, H, W)
    with torch.no_grad():
        outputs = model(x)
        probs = torch.nn.functional.softmax(outputs, dim=1)[0]
    print("Logits:", outputs)
    print("Probs:", probs)
    print("Sum of probs:", probs.sum())
    top5 = torch.topk(probs, k=5)
    top1_idx = int(top5.indices[0])
    top1_label = class_names[top1_idx]

    # Select a random reference image from the predicted class's subfolder
    class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
    reference_image = None
    if os.path.isdir(class_folder):
        image_files = [f for f in os.listdir(class_folder)
                       if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
        if image_files:
            chosen_file = random.choice(image_files)
            ref_path = os.path.join(class_folder, chosen_file)
            print(f"[DEBUG] Randomly selected reference image: {ref_path}")
            reference_image = Image.open(ref_path).convert("RGB")
        else:
            print(f"[DEBUG] No images found in {class_folder}")
    else:
        print(f"[DEBUG] Class folder does not exist: {class_folder}")

    top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
    print(f"image path: {image_path}")
    print(f"top1_label: {top1_label}")
    print(f"[DEBUG] Top-5 indices: {top5.indices}")
    print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
    print(f"[DEBUG] Top-5 probs: {top5_probs}")
    return image, reference_image, top5_probs
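
# Example usage (a minimal sketch; "test.jpg" is a placeholder path, not part
# of the original script):
if __name__ == "__main__":
    query_image, ref_image, scores = predict("test.jpg")
    print("Top-5 class probabilities:", scores)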