Noha90 commited on
Commit
092bd1b
·
1 Parent(s): 52b50f7

change to base deit

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. predict.py +2 -2
app.py CHANGED
@@ -9,7 +9,7 @@ demo = gr.Interface(
9
  gr.Image(label="Top-1 Class Example"),
10
  gr.Label(label="Top-5 Probabilities")
11
  ],
12
- title="Scene Classification with Reference Image Tesing using DeiT version 1.0",
13
  description="Upload an image to get the predicted class with a sample image and top-5 prediction chart."
14
  )
15
 
 
9
  gr.Image(label="Top-1 Class Example"),
10
  gr.Label(label="Top-5 Probabilities")
11
  ],
12
+ title="Scene Classification with Reference Image Testing using DeiT version 1.0",
13
  description="Upload an image to get the predicted class with a sample image and top-5 prediction chart."
14
  )
15
 
predict.py CHANGED
@@ -17,7 +17,7 @@ with open("labels.json", "r") as f:
17
  print("class_names:", class_names)
18
 
19
  class DeiT(nn.Module):
20
- def __init__(self, model_name="facebook/deit-small-patch16-224", num_classes=40):
21
  super(DeiT, self).__init__()
22
  self.model = ViTForImageClassification.from_pretrained(model_name)
23
  in_features = self.model.classifier.in_features
@@ -30,7 +30,7 @@ class DeiT(nn.Module):
30
  return outputs.logits
31
 
32
  # Load model
33
- model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_best_model_1.pth")
34
  print("Model path:", model_path)
35
  model = DeiT(model_name="facebook/deit-tiny-patch16-224", num_classes=40)
36
  state_dict = torch.load(model_path, map_location="cpu")
 
17
  print("class_names:", class_names)
18
 
19
  class DeiT(nn.Module):
20
+ def __init__(self, model_name="facebook/deit-base-patch16-224", num_classes=40):
21
  super(DeiT, self).__init__()
22
  self.model = ViTForImageClassification.from_pretrained(model_name)
23
  in_features = self.model.classifier.in_features
 
30
  return outputs.logits
31
 
32
  # Load model
33
+ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="deit_base_best_model.pth")
34
  print("Model path:", model_path)
35
  model = DeiT(model_name="facebook/deit-tiny-patch16-224", num_classes=40)
36
  state_dict = torch.load(model_path, map_location="cpu")