Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ from torchvision import transforms
|
|
10 |
from PIL import Image
|
11 |
import tempfile
|
12 |
from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
|
|
|
13 |
|
14 |
class Config:
|
15 |
# os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/main/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
|
@@ -28,8 +29,16 @@ class Config:
|
|
28 |
class ModelManager:
|
29 |
@staticmethod
|
30 |
def load_model(checkpoint_name: str):
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
model.eval()
|
34 |
model.to("cuda")
|
35 |
return model
|
|
|
10 |
from PIL import Image
|
11 |
import tempfile
|
12 |
from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
|
15 |
class Config:
|
16 |
# os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/main/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
|
|
|
29 |
class ModelManager:
|
30 |
@staticmethod
|
31 |
def load_model(checkpoint_name: str):
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
# 下载模型到本地(默认存储在 ~/.cache/huggingface/hub)
|
36 |
+
model_path = hf_hub_download(
|
37 |
+
repo_id="shimu0215/seg", # 你的模型仓库
|
38 |
+
filename="sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2", # 你的模型文件
|
39 |
+
)
|
40 |
+
# checkpoint_path = os.path.join('shimu0215/seg', Config.CHECKPOINTS[checkpoint_name])
|
41 |
+
model = torch.jit.load(model_path)
|
42 |
model.eval()
|
43 |
model.to("cuda")
|
44 |
return model
|