Noha90 committed on
Commit
925a7e9
·
1 Parent(s): c8dd81c

Change to the final best model (switch from ViT-base to Swin-Large)

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. predict.py +24 -23
app.py CHANGED
@@ -9,7 +9,7 @@ demo = gr.Interface(
9
  gr.Image(label="Top-1 Class Example"),
10
  gr.Label(label="Top-5 Probabilities")
11
  ],
12
- title="Scene Classification with Reference Image Testing using VIT version 1.0",
13
  description="Upload an image to get the predicted class with a sample image and top-5 prediction chart."
14
  )
15
 
 
9
  gr.Image(label="Top-1 Class Example"),
10
  gr.Label(label="Top-5 Probabilities")
11
  ],
12
+ title="Scene Classification with Reference Image Testing using Large SWIN version 1.0",
13
  description="Upload an image to get the predicted class with a sample image and top-5 prediction chart."
14
  )
15
 
predict.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.init as init
4
- from transformers import AutoImageProcessor, AutoModelForImageClassification
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import json
@@ -14,50 +14,51 @@ with open("labels.json", "r") as f:
14
  class_names = json.load(f)
15
  print("class_names:", class_names)
16
 
17
- class ViT(nn.Module):
18
-
19
- def __init__(self, model_name="google/vit-base-patch16-224", num_classes=40, dropout_rate=0.1):
20
- super(ViT, self).__init__()
21
- self.extractor = AutoImageProcessor.from_pretrained(model_name)
22
- self.model = AutoModelForImageClassification.from_pretrained(model_name)
23
  in_features = self.model.classifier.in_features
24
  self.model.classifier = nn.Sequential(
25
- nn.Dropout(p=dropout_rate),
 
 
26
  nn.Linear(in_features, num_classes)
27
  )
28
- self.img_size = (self.extractor.size['height'], self.extractor.size['width'])
29
- self.normalize = transforms.Normalize(mean=self.extractor.image_mean, std=self.extractor.image_std)
 
 
30
 
31
  def forward(self, images):
32
- outputs = self.model(pixel_values=images)
33
  return outputs.logits
34
 
35
- def get_test_transforms(self):
36
- return transforms.Compose([
37
- transforms.Resize(self.img_size),
38
- transforms.ToTensor(),
39
- self.normalize
40
- ])
41
-
42
- # Download your fine-tuned model checkpoint from the Hub
43
- model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="vit_best_model.pth")
44
  print("Model path:", model_path)
45
- model = ViT(model_name="google/vit-base-patch16-224", num_classes=40)
46
  state_dict = torch.load(model_path, map_location="cpu")
47
  if "model_state_dict" in state_dict:
48
  state_dict = state_dict["model_state_dict"]
49
  model.load_state_dict(state_dict, strict=False)
50
  model.eval()
51
 
52
- transform = model.get_test_transforms()
 
 
 
 
 
53
 
54
  def predict(image_path):
55
  image = Image.open(image_path).convert("RGB")
56
  x = transform(image).unsqueeze(0)
57
  with torch.no_grad():
58
  outputs = model(x)
59
- probs = torch.nn.functional.softmax(outputs, dim=1)[0]
60
  print("Logits:", outputs)
 
61
  print("Probs:", probs)
62
  print("Sum of probs:", probs.sum())
63
  top5 = torch.topk(probs, k=5)
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.init as init
4
+ from transformers import SwinForImageClassification
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import json
 
14
  class_names = json.load(f)
15
  print("class_names:", class_names)
16
 
17
+ MODEL_NAME = "microsoft/swin-large-patch4-window7-224"
18
+
19
+ class SwinCustom(nn.Module):
20
+ def __init__(self, model_name=MODEL_NAME, num_classes=40):
21
+ super(SwinCustom, self).__init__()
22
+ self.model = SwinForImageClassification.from_pretrained(model_name, num_labels=num_classes, ignore_mismatched_sizes=True)
23
  in_features = self.model.classifier.in_features
24
  self.model.classifier = nn.Sequential(
25
+ nn.Linear(in_features, in_features),
26
+ nn.LeakyReLU(),
27
+ nn.Dropout(0.3),
28
  nn.Linear(in_features, num_classes)
29
  )
30
+ # Weight initialization
31
+ for m in self.model.classifier:
32
+ if isinstance(m, nn.Linear):
33
+ init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
34
 
35
  def forward(self, images):
36
+ outputs = self.model(images)
37
  return outputs.logits
38
 
39
+ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
 
 
 
 
 
 
 
 
40
  print("Model path:", model_path)
41
+ model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
42
  state_dict = torch.load(model_path, map_location="cpu")
43
  if "model_state_dict" in state_dict:
44
  state_dict = state_dict["model_state_dict"]
45
  model.load_state_dict(state_dict, strict=False)
46
  model.eval()
47
 
48
+ # Preprocessing
49
+ transform = transforms.Compose([
50
+ transforms.Resize((224, 224)),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
53
+ ])
54
 
55
  def predict(image_path):
56
  image = Image.open(image_path).convert("RGB")
57
  x = transform(image).unsqueeze(0)
58
  with torch.no_grad():
59
  outputs = model(x)
 
60
  print("Logits:", outputs)
61
+ probs = torch.nn.functional.softmax(outputs, dim=1)[0]
62
  print("Probs:", probs)
63
  print("Sum of probs:", probs.sum())
64
  top5 = torch.topk(probs, k=5)