Noha90 committed on
Commit
07132cf
·
1 Parent(s): 1d84837

use custom vit

Browse files
Files changed (1) hide show
  1. predict.py +44 -24
predict.py CHANGED
@@ -3,7 +3,6 @@ from torchvision import transforms
3
  from PIL import Image
4
  import json
5
  import numpy as np
6
- # from model import load_model
7
  from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification, ViTImageProcessor
8
  import torch.nn as nn
9
  import os
@@ -16,53 +15,74 @@ with open("labels.json", "r") as f:
16
  class_names = json.load(f)
17
  print("class_names:", class_names)
18
 
19
- # Download your fine-tuned model checkpoint from the Hub
20
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="vit_best_model.pth")
21
 
22
  processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
23
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
24
 
25
- # Replace the classifier head for your number of classes
26
- in_features = model.classifier.in_features
27
- model.classifier = nn.Linear(in_features, 40)
 
 
 
28
 
 
 
 
 
 
 
29
  state_dict = torch.load(model_path, map_location="cpu")
30
  if "model_state_dict" in state_dict:
31
  state_dict = state_dict["model_state_dict"]
32
  model.load_state_dict(state_dict, strict=False)
33
  model.eval()
34
- # transform = transforms.Compose([
35
- # transforms.Resize((224, 224)),
36
- # transforms.ToTensor(),
37
- # transforms.Normalize(mean=[0.485, 0.456, 0.406],
38
- # std=[0.229, 0.224, 0.225])
39
- # ])
 
 
40
  def predict(image_path):
41
- # Load and prepare image
42
  image = Image.open(image_path).convert("RGB")
43
- inputs = processor(images=image, return_tensors="pt")
44
 
45
  with torch.no_grad():
46
- outputs = model(**inputs)
47
- logits = outputs.logits
48
- probs = torch.nn.functional.softmax(logits, dim=1)[0]
49
- print("Logits:", logits)
50
  print("Probs:", probs)
51
  print("Sum of probs:", probs.sum())
52
  top5 = torch.topk(probs, k=5)
53
 
54
- predicted_class_idx = logits.argmax(-1).item()
55
- top1_label = model.config.id2label[predicted_class_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Format Top-5 for gr.Label with class names
58
- top5_probs = {model.config.id2label[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
59
  print(f"image path: {image_path}")
60
  print(f"top1_label: {top1_label}")
61
  print(f"[DEBUG] Top-5 indices: {top5.indices}")
62
- print(f"[DEBUG] Top-5 labels: {[model.config.id2label[int(idx)] for idx in top5.indices]}")
63
  print(f"[DEBUG] Top-5 probs: {top5_probs}")
64
 
65
- return image, None, top5_probs
66
 
67
 
68
 
 
3
  from PIL import Image
4
  import json
5
  import numpy as np
 
6
  from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification, ViTImageProcessor
7
  import torch.nn as nn
8
  import os
 
15
  class_names = json.load(f)
16
  print("class_names:", class_names)
17
 
 
18
  model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="vit_best_model.pth")
19
 
20
  processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
 
21
 
22
+ class ViTCustom(nn.Module):
23
+ def __init__(self, model_name="google/vit-base-patch16-224", num_classes=40):
24
+ super(ViTCustom, self).__init__()
25
+ self.model = ViTForImageClassification.from_pretrained(model_name)
26
+ in_features = self.model.classifier.in_features
27
+ self.model.classifier = nn.Linear(in_features, num_classes)
28
 
29
+ def forward(self, images):
30
+ outputs = self.model(images)
31
+ return outputs.logits
32
+
33
+ # Load model
34
+ model = ViTCustom(model_name="google/vit-base-patch16-224", num_classes=40)
35
  state_dict = torch.load(model_path, map_location="cpu")
36
  if "model_state_dict" in state_dict:
37
  state_dict = state_dict["model_state_dict"]
38
  model.load_state_dict(state_dict, strict=False)
39
  model.eval()
40
+
41
+ transform = transforms.Compose([
42
+ transforms.Resize((224, 224)),
43
+ transforms.ToTensor(),
44
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
45
+ std=[0.229, 0.224, 0.225])
46
+ ])
47
+
48
  def predict(image_path):
 
49
  image = Image.open(image_path).convert("RGB")
50
+ x = transform(image).unsqueeze(0)
51
 
52
  with torch.no_grad():
53
+ outputs = model(x)
54
+ print("Logits:", outputs)
55
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0]
 
56
  print("Probs:", probs)
57
  print("Sum of probs:", probs.sum())
58
  top5 = torch.topk(probs, k=5)
59
 
60
+ top1_idx = int(top5.indices[0])
61
+ top1_label = class_names[top1_idx]
62
+
63
+ # Select a random image from the class subfolder
64
+ class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
65
+ reference_image = None
66
+ if os.path.isdir(class_folder):
67
+ image_files = [f for f in os.listdir(class_folder) if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
68
+ if image_files:
69
+ chosen_file = random.choice(image_files)
70
+ ref_path = os.path.join(class_folder, chosen_file)
71
+ print(f"[DEBUG] Randomly selected reference image: {ref_path}")
72
+ reference_image = Image.open(ref_path).convert("RGB")
73
+ else:
74
+ print(f"[DEBUG] No images found in {class_folder}")
75
+ else:
76
+ print(f"[DEBUG] Class folder does not exist: {class_folder}")
77
 
78
+ top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
 
79
  print(f"image path: {image_path}")
80
  print(f"top1_label: {top1_label}")
81
  print(f"[DEBUG] Top-5 indices: {top5.indices}")
82
+ print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
83
  print(f"[DEBUG] Top-5 probs: {top5_probs}")
84
 
85
+ return image, reference_image, top5_probs
86
 
87
 
88