Noha90 committed on
Commit
db6ba08
·
1 Parent(s): 4a73cb4
Files changed (1) hide show
  1. predict.py +22 -24
predict.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.init as init
4
- from transformers import SwinForImageClassification
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import json
@@ -14,52 +14,50 @@ with open("labels.json", "r") as f:
14
  class_names = json.load(f)
15
  print("class_names:", class_names)
16
 
17
- MODEL_NAME = "microsoft/swin-large-patch4-window7-224" # or your preferred Swin model
18
-
19
- class SwinCustom(nn.Module):
20
- def __init__(self, model_name=MODEL_NAME, num_classes=40):
21
- super(SwinCustom, self).__init__()
22
- self.model = SwinForImageClassification.from_pretrained(model_name, num_labels=num_classes, ignore_mismatched_sizes=True)
23
  in_features = self.model.classifier.in_features
24
  self.model.classifier = nn.Sequential(
25
- nn.Linear(in_features, in_features),
26
- nn.LeakyReLU(),
27
- nn.Dropout(0.3),
28
  nn.Linear(in_features, num_classes)
29
  )
30
- # Weight initialization
31
- for m in self.model.classifier:
32
- if isinstance(m, nn.Linear):
33
- init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
34
 
35
  def forward(self, images):
36
- outputs = self.model(images)
37
  return outputs.logits
38
 
 
 
 
 
 
 
 
39
  # Download your fine-tuned model checkpoint from the Hub
40
- model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
41
  print("Model path:", model_path)
42
- model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
43
  state_dict = torch.load(model_path, map_location="cpu")
44
  if "model_state_dict" in state_dict:
45
  state_dict = state_dict["model_state_dict"]
46
  model.load_state_dict(state_dict, strict=False)
47
  model.eval()
48
 
49
- # Preprocessing (Swin default)
50
- transform = transforms.Compose([
51
- transforms.Resize((224, 224)),
52
- transforms.ToTensor(),
53
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
54
- ])
55
 
56
  def predict(image_path):
57
  image = Image.open(image_path).convert("RGB")
58
  x = transform(image).unsqueeze(0)
59
  with torch.no_grad():
60
  outputs = model(x)
61
- print("Logits:", outputs)
62
  probs = torch.nn.functional.softmax(outputs, dim=1)[0]
 
63
  print("Probs:", probs)
64
  print("Sum of probs:", probs.sum())
65
  top5 = torch.topk(probs, k=5)
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.init as init
4
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import json
 
14
  class_names = json.load(f)
15
  print("class_names:", class_names)
16
 
17
+ class ViT(nn.Module):
18
+
19
+ def __init__(self, model_name="google/vit-base-patch16-224", num_classes=40, dropout_rate=0.1):
20
+ super(ViT, self).__init__()
21
+ self.extractor = AutoImageProcessor.from_pretrained(model_name)
22
+ self.model = AutoModelForImageClassification.from_pretrained(model_name)
23
  in_features = self.model.classifier.in_features
24
  self.model.classifier = nn.Sequential(
25
+ nn.Dropout(p=dropout_rate),
 
 
26
  nn.Linear(in_features, num_classes)
27
  )
28
+ self.img_size = (self.extractor.size['height'], self.extractor.size['width'])
29
+ self.normalize = transforms.Normalize(mean=self.extractor.image_mean, std=self.extractor.image_std)
 
 
30
 
31
  def forward(self, images):
32
+ outputs = self.model(pixel_values=images)
33
  return outputs.logits
34
 
35
+ def get_test_transforms(self):
36
+ return transforms.Compose([
37
+ transforms.Resize(self.img_size),
38
+ transforms.ToTensor(),
39
+ self.normalize
40
+ ])
41
+
42
  # Download your fine-tuned model checkpoint from the Hub
43
+ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="vit_best_model.pth")
44
  print("Model path:", model_path)
45
+ model = ViT(model_name="google/vit-base-patch16-224", num_classes=40)
46
  state_dict = torch.load(model_path, map_location="cpu")
47
  if "model_state_dict" in state_dict:
48
  state_dict = state_dict["model_state_dict"]
49
  model.load_state_dict(state_dict, strict=False)
50
  model.eval()
51
 
52
+ transform = model.get_test_transforms()
 
 
 
 
 
53
 
54
  def predict(image_path):
55
  image = Image.open(image_path).convert("RGB")
56
  x = transform(image).unsqueeze(0)
57
  with torch.no_grad():
58
  outputs = model(x)
 
59
  probs = torch.nn.functional.softmax(outputs, dim=1)[0]
60
+ print("Logits:", outputs)
61
  print("Probs:", probs)
62
  print("Sum of probs:", probs.sum())
63
  top5 = torch.topk(probs, k=5)