shimu0215 commited on
Commit
41ee966
·
verified ·
1 Parent(s): de64dcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -10,6 +10,7 @@ from torchvision import transforms
10
  from PIL import Image
11
  import tempfile
12
  from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
 
13
 
14
  class Config:
15
  # os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/main/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
@@ -28,8 +29,16 @@ class Config:
28
  class ModelManager:
29
  @staticmethod
30
  def load_model(checkpoint_name: str):
31
- checkpoint_path = os.path.join('shimu0215/seg', Config.CHECKPOINTS[checkpoint_name])
32
- model = torch.jit.load(checkpoint_path)
 
 
 
 
 
 
 
 
33
  model.eval()
34
  model.to("cuda")
35
  return model
 
10
  from PIL import Image
11
  import tempfile
12
  from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
13
+ from huggingface_hub import hf_hub_download
14
 
15
  class Config:
16
  # os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/main/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
 
29
  class ModelManager:
30
  @staticmethod
31
  def load_model(checkpoint_name: str):
32
+
33
+
34
+
35
+ # 下载模型到本地(默认存储在 ~/.cache/huggingface/hub)
36
+ model_path = hf_hub_download(
37
+ repo_id="shimu0215/seg", # 你的模型仓库
38
+ filename="sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2", # 你的模型文件
39
+ )
40
+ # checkpoint_path = os.path.join('shimu0215/seg', Config.CHECKPOINTS[checkpoint_name])
41
+ model = torch.jit.load(model_path)
42
  model.eval()
43
  model.to("cuda")
44
  return model