Noha90 commited on
Commit
f7e2efb
·
1 Parent(s): a531bac

update final deploy

Browse files
Files changed (1) hide show
  1. predict.py +3 -15
predict.py CHANGED
@@ -8,7 +8,6 @@ import json
8
  import os
9
  import random
10
  from torchvision import transforms
11
- from torch.quantization import quantize_dynamic
12
 
13
  # Load labels
14
  with open("labels.json", "r") as f:
@@ -37,25 +36,14 @@ class SwinCustom(nn.Module):
37
  outputs = self.model(images)
38
  return outputs.logits
39
 
40
- # Download quantized model weights
41
- model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="swin_large_quantised.pth")
42
  print("Model path:", model_path)
43
-
44
- # Build the model
45
  model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
46
-
47
- # Quantize the model
48
- quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
49
-
50
- # Load the quantized state dict
51
  state_dict = torch.load(model_path, map_location="cpu")
52
  if "model_state_dict" in state_dict:
53
  state_dict = state_dict["model_state_dict"]
54
- quantized_model.load_state_dict(state_dict, strict=False)
55
- quantized_model.eval()
56
-
57
- # Use quantized_model
58
- model = quantized_model
59
 
60
  # Preprocessing
61
  transform = transforms.Compose([
 
8
  import os
9
  import random
10
  from torchvision import transforms
 
11
 
12
  # Load labels
13
  with open("labels.json", "r") as f:
 
36
  outputs = self.model(images)
37
  return outputs.logits
38
 
39
+ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
 
40
  print("Model path:", model_path)
 
 
41
  model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
 
 
 
 
 
42
  state_dict = torch.load(model_path, map_location="cpu")
43
  if "model_state_dict" in state_dict:
44
  state_dict = state_dict["model_state_dict"]
45
+ model.load_state_dict(state_dict, strict=False)
46
+ model.eval()
 
 
 
47
 
48
  # Preprocessing
49
  transform = transforms.Compose([