thng292 commited on
Commit
6f3dcb0
·
verified ·
1 Parent(s): d0ac7e9

Upload landmark_detection.py

Browse files
Files changed (1) hide show
  1. landmark_detection.py +13 -2
landmark_detection.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import torch
2
  import torchvision
 
3
  from torchvision.transforms import InterpolationMode
4
  from network.models.facexformer import FaceXFormer
5
  from dataclasses import dataclass
@@ -12,8 +14,8 @@ import numpy as np
12
  # device = "cuda:0"
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  dtype = torch.float32
15
- # weights_path = "ckpts/model.pt"
16
- weights_path = "ckpts/pytorch_model.bin"
17
  # face_model_path = "ckpts/blaze_face_short_range.tflite"
18
 
19
  # import mediapipe as mp
@@ -46,6 +48,15 @@ transforms_image = torchvision.transforms.Compose(
46
 
47
  def load_model(weights_path):
48
  model = FaceXFormer().to(device)
 
 
 
 
 
 
 
 
 
49
  checkpoint = torch.load(weights_path, map_location=device)
50
  model.load_state_dict(checkpoint)
51
  # model.load_state_dict(checkpoint["state_dict_backbone"])
 
1
+ import os
2
  import torch
3
  import torchvision
4
+ import huggingface_hub
5
  from torchvision.transforms import InterpolationMode
6
  from network.models.facexformer import FaceXFormer
7
  from dataclasses import dataclass
 
14
  # device = "cuda:0"
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  dtype = torch.float32
17
+ weights_path = "ckpts/model.pt"
18
+ # weights_path = "ckpts/pytorch_model.bin"
19
  # face_model_path = "ckpts/blaze_face_short_range.tflite"
20
 
21
  # import mediapipe as mp
 
48
 
49
  def load_model(weights_path):
50
  model = FaceXFormer().to(device)
51
+ if not os.path.exists(weights_path):
52
+ name = os.path.basename(weights_path)
53
+ path = os.path.dirname(weights_path)
54
+ huggingface_hub.hf_hub_download(
55
+ "kartiknarayan/facexformer",
56
+ "ckpts/model.pt",
57
+ repo_type="model",
58
+ local_dir=path,
59
+ )
60
  checkpoint = torch.load(weights_path, map_location=device)
61
  model.load_state_dict(checkpoint)
62
  # model.load_state_dict(checkpoint["state_dict_backbone"])