Noha90 commited on
Commit
a531bac
·
1 Parent(s): fa25033

test quantized model

Browse files
Files changed (1) hide show
  1. predict.py +15 -3
predict.py CHANGED
@@ -8,6 +8,7 @@ import json
8
  import os
9
  import random
10
  from torchvision import transforms
 
11
 
12
  # Load labels
13
  with open("labels.json", "r") as f:
@@ -36,14 +37,25 @@ class SwinCustom(nn.Module):
36
  outputs = self.model(images)
37
  return outputs.logits
38
 
39
- model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="large_swin_best_model.pth")
 
40
  print("Model path:", model_path)
 
 
41
  model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
 
 
 
 
 
42
  state_dict = torch.load(model_path, map_location="cpu")
43
  if "model_state_dict" in state_dict:
44
  state_dict = state_dict["model_state_dict"]
45
- model.load_state_dict(state_dict, strict=False)
46
- model.eval()
 
 
 
47
 
48
  # Preprocessing
49
  transform = transforms.Compose([
 
8
  import os
9
  import random
10
  from torchvision import transforms
11
+ from torch.quantization import quantize_dynamic
12
 
13
  # Load labels
14
  with open("labels.json", "r") as f:
 
37
  outputs = self.model(images)
38
  return outputs.logits
39
 
40
+ # Download quantized model weights
41
+ model_path = hf_hub_download(repo_id="Noha90/AML_16", filename="swin_large_quantised.pth")
42
  print("Model path:", model_path)
43
+
44
+ # Build the model
45
  model = SwinCustom(model_name=MODEL_NAME, num_classes=40)
46
+
47
+ # Quantize the model
48
+ quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
49
+
50
+ # Load the quantized state dict
51
  state_dict = torch.load(model_path, map_location="cpu")
52
  if "model_state_dict" in state_dict:
53
  state_dict = state_dict["model_state_dict"]
54
+ quantized_model.load_state_dict(state_dict, strict=False)
55
+ quantized_model.eval()
56
+
57
+ # Use quantized_model
58
+ model = quantized_model
59
 
60
  # Preprocessing
61
  transform = transforms.Compose([