Noha90 committed on
Commit
52b50f7
·
1 Parent(s): 07132cf

roll back to deit

Browse files
Files changed (1) hide show
  1. predict.py +17 -11
predict.py CHANGED
@@ -3,7 +3,8 @@ from torchvision import transforms
3
  from PIL import Image
4
  import json
5
  import numpy as np
6
- from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification, ViTImageProcessor
 
7
  import torch.nn as nn
8
  import os
9
  import pandas as pd
@@ -15,29 +16,31 @@ with open("labels.json", "r") as f:
15
  class_names = json.load(f)
16
  print("class_names:", class_names)
17
 
18
- model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="vit_best_model.pth")
19
-
20
- processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
21
-
22
- class ViTCustom(nn.Module):
23
- def __init__(self, model_name="google/vit-base-patch16-224", num_classes=40):
24
- super(ViTCustom, self).__init__()
25
  self.model = ViTForImageClassification.from_pretrained(model_name)
26
  in_features = self.model.classifier.in_features
27
- self.model.classifier = nn.Linear(in_features, num_classes)
 
 
28
 
29
  def forward(self, images):
30
  outputs = self.model(images)
31
  return outputs.logits
32
 
33
  # Load model
34
- model = ViTCustom(model_name="google/vit-base-patch16-224", num_classes=40)
 
 
35
  state_dict = torch.load(model_path, map_location="cpu")
36
  if "model_state_dict" in state_dict:
37
  state_dict = state_dict["model_state_dict"]
38
  model.load_state_dict(state_dict, strict=False)
39
  model.eval()
40
 
 
 
41
  transform = transforms.Compose([
42
  transforms.Resize((224, 224)),
43
  transforms.ToTensor(),
@@ -45,7 +48,9 @@ transform = transforms.Compose([
45
  std=[0.229, 0.224, 0.225])
46
  ])
47
 
 
48
  def predict(image_path):
 
49
  image = Image.open(image_path).convert("RGB")
50
  x = transform(image).unsqueeze(0)
51
 
@@ -64,6 +69,7 @@ def predict(image_path):
64
  class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
65
  reference_image = None
66
  if os.path.isdir(class_folder):
 
67
  image_files = [f for f in os.listdir(class_folder) if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
68
  if image_files:
69
  chosen_file = random.choice(image_files)
@@ -75,6 +81,7 @@ def predict(image_path):
75
  else:
76
  print(f"[DEBUG] Class folder does not exist: {class_folder}")
77
 
 
78
  top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
79
  print(f"image path: {image_path}")
80
  print(f"top1_label: {top1_label}")
@@ -85,4 +92,3 @@ def predict(image_path):
85
  return image, reference_image, top5_probs
86
 
87
 
88
-
 
3
  from PIL import Image
4
  import json
5
  import numpy as np
6
+ # from model import load_model
7
+ from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification
8
  import torch.nn as nn
9
  import os
10
  import pandas as pd
 
16
  class_names = json.load(f)
17
  print("class_names:", class_names)
18
 
19
+ class DeiT(nn.Module):
20
+ def __init__(self, model_name="facebook/deit-small-patch16-224", num_classes=40):
21
+ super(DeiT, self).__init__()
 
 
 
 
22
  self.model = ViTForImageClassification.from_pretrained(model_name)
23
  in_features = self.model.classifier.in_features
24
+ self.model.classifier = nn.Sequential(
25
+ nn.Linear(in_features, num_classes)
26
+ )
27
 
28
  def forward(self, images):
29
  outputs = self.model(images)
30
  return outputs.logits
31
 
32
  # Load model
33
+ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_best_model_1.pth")
34
+ print("Model path:", model_path)
35
+ model = DeiT(model_name="facebook/deit-tiny-patch16-224", num_classes=40)
36
  state_dict = torch.load(model_path, map_location="cpu")
37
  if "model_state_dict" in state_dict:
38
  state_dict = state_dict["model_state_dict"]
39
  model.load_state_dict(state_dict, strict=False)
40
  model.eval()
41
 
42
+
43
+ # Image preprocessing (resize to 224x224, tensor conversion, normalization)
44
  transform = transforms.Compose([
45
  transforms.Resize((224, 224)),
46
  transforms.ToTensor(),
 
48
  std=[0.229, 0.224, 0.225])
49
  ])
50
 
51
+
52
  def predict(image_path):
53
+ # Load and prepare image
54
  image = Image.open(image_path).convert("RGB")
55
  x = transform(image).unsqueeze(0)
56
 
 
69
  class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
70
  reference_image = None
71
  if os.path.isdir(class_folder):
72
+ # List all image files in the folder
73
  image_files = [f for f in os.listdir(class_folder) if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
74
  if image_files:
75
  chosen_file = random.choice(image_files)
 
81
  else:
82
  print(f"[DEBUG] Class folder does not exist: {class_folder}")
83
 
84
+ # Format Top-5 for gr.Label with class names
85
  top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
86
  print(f"image path: {image_path}")
87
  print(f"top1_label: {top1_label}")
 
92
  return image, reference_image, top5_probs
93
 
94