Noha90 committed on
Commit
7220b97
·
1 Parent(s): 02a7128

test large swin

Browse files
Files changed (1) hide show
  1. predict.py +24 -26
predict.py CHANGED
@@ -1,59 +1,61 @@
1
  import torch
2
- from torchvision import transforms
 
 
 
3
  from PIL import Image
4
  import json
5
- import numpy as np
6
- # from model import load_model
7
- from transformers import AutoImageProcessor, SwinForImageClassification, ViTForImageClassification
8
- import torch.nn as nn
9
  import os
10
- import pandas as pd
11
  import random
12
- from huggingface_hub import hf_hub_download
13
 
14
  # Load labels
15
  with open("labels.json", "r") as f:
16
  class_names = json.load(f)
17
  print("class_names:", class_names)
18
 
19
- class DeiT(nn.Module):
20
- def __init__(self, model_name="facebook/deit-base-patch16-224", num_classes=40):
21
- super(DeiT, self).__init__()
22
- self.model = ViTForImageClassification.from_pretrained(model_name)
 
 
23
  in_features = self.model.classifier.in_features
24
  self.model.classifier = nn.Sequential(
 
 
 
25
  nn.Linear(in_features, num_classes)
26
  )
 
 
 
 
27
 
28
  def forward(self, images):
29
  outputs = self.model(images)
30
  return outputs.logits
31
 
32
- # Load model
33
- model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_base_best_model.pth")
34
  print("Model path:", model_path)
35
- model = DeiT(model_name="facebook/deit-base-patch16-224", num_classes=40)
36
  state_dict = torch.load(model_path, map_location="cpu")
37
  if "model_state_dict" in state_dict:
38
  state_dict = state_dict["model_state_dict"]
39
  model.load_state_dict(state_dict, strict=False)
40
  model.eval()
41
 
42
-
43
- #Swin
44
  transform = transforms.Compose([
45
  transforms.Resize((224, 224)),
46
  transforms.ToTensor(),
47
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
48
- std=[0.229, 0.224, 0.225])
49
  ])
50
 
51
-
52
  def predict(image_path):
53
- # Load and prepare image
54
  image = Image.open(image_path).convert("RGB")
55
  x = transform(image).unsqueeze(0)
56
-
57
  with torch.no_grad():
58
  outputs = model(x)
59
  print("Logits:", outputs)
@@ -69,7 +71,6 @@ def predict(image_path):
69
  class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
70
  reference_image = None
71
  if os.path.isdir(class_folder):
72
- # List all image files in the folder
73
  image_files = [f for f in os.listdir(class_folder) if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
74
  if image_files:
75
  chosen_file = random.choice(image_files)
@@ -81,7 +82,6 @@ def predict(image_path):
81
  else:
82
  print(f"[DEBUG] Class folder does not exist: {class_folder}")
83
 
84
- # Format Top-5 for gr.Label with class names
85
  top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
86
  print(f"image path: {image_path}")
87
  print(f"top1_label: {top1_label}")
@@ -89,6 +89,4 @@ def predict(image_path):
89
  print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
90
  print(f"[DEBUG] Top-5 probs: {top5_probs}")
91
 
92
- return image, reference_image, top5_probs
93
-
94
-
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ from transformers import SwinForImageClassification
5
+ from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import json
 
 
 
 
8
  import os
 
9
  import random
10
+ from torchvision import transforms
11
 
# Load labels.
# The file is decoded as UTF-8 (the standard JSON encoding) instead of the
# platform's locale encoding, so class names containing non-ASCII characters
# load identically on every host. predict() looks names up with
# class_names[int(idx)].
with open("labels.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
print("class_names:", class_names)

# Hub id of the pretrained Swin checkpoint used as the backbone
# (or your preferred Swin model).
MODEL_NAME = "microsoft/swin-large-patch4-window7-224"
18
+
19
class SwinCustom(nn.Module):
    """Swin image classifier whose stock head is replaced by a two-layer MLP.

    The wrapped model comes from ``SwinForImageClassification.from_pretrained``.
    Its ``classifier`` attribute is swapped for
    ``Linear -> LeakyReLU -> Dropout(0.3) -> Linear`` and the weights of both
    ``Linear`` layers are re-initialised with Kaiming-uniform.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        num_classes: Number of output logits produced by the new head.
    """

    def __init__(self, model_name=MODEL_NAME, num_classes=40):
        super().__init__()
        backbone = SwinForImageClassification.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        width = backbone.classifier.in_features

        # Layer order is kept as-is so state-dict keys
        # (model.classifier.0.*, model.classifier.3.*) stay the same.
        head = nn.Sequential(
            nn.Linear(width, width),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(width, num_classes),
        )

        # Weight initialization: only the Linear layers carry a weight matrix.
        for layer in head:
            if not isinstance(layer, nn.Linear):
                continue
            init.kaiming_uniform_(layer.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

        backbone.classifier = head
        self.model = backbone

    def forward(self, images):
        """Run the wrapped model on ``images`` and return its ``logits``."""
        return self.model(images).logits
38
 
39
# Download your fine-tuned model checkpoint from the Hub
model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
print("Model path:", model_path)
model = SwinCustom(model_name=MODEL_NAME, num_classes=40)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code
# if the checkpoint is not trusted. Consider passing weights_only=True once it
# is confirmed that the checkpoint holds only tensors and plain containers.
state_dict = torch.load(model_path, map_location="cpu")
# The checkpoint may wrap the weights under a "model_state_dict" entry;
# unwrap it so the keys line up with the module's parameters.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

# strict=False lets loading continue when keys do not match, which would
# otherwise leave the affected layers at their random initialisation without
# any notice. Report the mismatches so a wrong checkpoint is visible in the log.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print("[WARN] Keys missing from checkpoint:", load_result.missing_keys)
if load_result.unexpected_keys:
    print("[WARN] Unexpected keys in checkpoint:", load_result.unexpected_keys)

# Inference mode: disables the Dropout layer in the classifier head.
model.eval()

# Preprocessing (Swin default)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
55
 
 
56
  def predict(image_path):
 
57
  image = Image.open(image_path).convert("RGB")
58
  x = transform(image).unsqueeze(0)
 
59
  with torch.no_grad():
60
  outputs = model(x)
61
  print("Logits:", outputs)
 
71
  class_folder = f"sample_images/{str(top1_label).replace(' ', '_')}"
72
  reference_image = None
73
  if os.path.isdir(class_folder):
 
74
  image_files = [f for f in os.listdir(class_folder) if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"))]
75
  if image_files:
76
  chosen_file = random.choice(image_files)
 
82
  else:
83
  print(f"[DEBUG] Class folder does not exist: {class_folder}")
84
 
 
85
  top5_probs = {class_names[int(idx)]: float(score) for idx, score in zip(top5.indices, top5.values)}
86
  print(f"image path: {image_path}")
87
  print(f"top1_label: {top1_label}")
 
89
  print(f"[DEBUG] Top-5 labels: {[class_names[int(idx)] for idx in top5.indices]}")
90
  print(f"[DEBUG] Top-5 probs: {top5_probs}")
91
 
92
+ return image, reference_image, top5_probs