Spaces:
Sleeping
Sleeping
Upload landmark_detection.py
Browse files- landmark_detection.py +13 -2
landmark_detection.py
CHANGED
@@ -1,5 +1,7 @@
|
|
|
|
1 |
import torch
|
2 |
import torchvision
|
|
|
3 |
from torchvision.transforms import InterpolationMode
|
4 |
from network.models.facexformer import FaceXFormer
|
5 |
from dataclasses import dataclass
|
@@ -12,8 +14,8 @@ import numpy as np
|
|
12 |
# device = "cuda:0"
|
13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
14 |
dtype = torch.float32
|
15 |
-
|
16 |
-
weights_path = "ckpts/pytorch_model.bin"
|
17 |
# face_model_path = "ckpts/blaze_face_short_range.tflite"
|
18 |
|
19 |
# import mediapipe as mp
|
@@ -46,6 +48,15 @@ transforms_image = torchvision.transforms.Compose(
|
|
46 |
|
47 |
def load_model(weights_path):
|
48 |
model = FaceXFormer().to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
checkpoint = torch.load(weights_path, map_location=device)
|
50 |
model.load_state_dict(checkpoint)
|
51 |
# model.load_state_dict(checkpoint["state_dict_backbone"])
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
import torchvision
|
4 |
+
import huggingface_hub
|
5 |
from torchvision.transforms import InterpolationMode
|
6 |
from network.models.facexformer import FaceXFormer
|
7 |
from dataclasses import dataclass
|
|
|
14 |
# device = "cuda:0"
|
15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
dtype = torch.float32
|
17 |
+
weights_path = "ckpts/model.pt"
|
18 |
+
# weights_path = "ckpts/pytorch_model.bin"
|
19 |
# face_model_path = "ckpts/blaze_face_short_range.tflite"
|
20 |
|
21 |
# import mediapipe as mp
|
|
|
48 |
|
49 |
def load_model(weights_path):
|
50 |
model = FaceXFormer().to(device)
|
51 |
+
if not os.path.exists(weights_path):
|
52 |
+
name = os.path.basename(weights_path)
|
53 |
+
path = os.path.dirname(weights_path)
|
54 |
+
huggingface_hub.hf_hub_download(
|
55 |
+
"kartiknarayan/facexformer",
|
56 |
+
"ckpts/model.pt",
|
57 |
+
repo_type="model",
|
58 |
+
local_dir=path,
|
59 |
+
)
|
60 |
checkpoint = torch.load(weights_path, map_location=device)
|
61 |
model.load_state_dict(checkpoint)
|
62 |
# model.load_state_dict(checkpoint["state_dict_backbone"])
|