shimu0215 commited on
Commit
20ce537
·
verified ·
1 Parent(s): 73e1151

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -12,13 +12,13 @@ import tempfile
12
  from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
13
 
14
  class Config:
15
- os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/main/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
16
 
17
 
18
- # ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
19
- # CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
20
- ASSETS_DIR = os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/")
21
- CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "main")
22
  CHECKPOINTS = {
23
  "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
24
  "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
@@ -28,7 +28,7 @@ class Config:
28
  class ModelManager:
29
  @staticmethod
30
  def load_model(checkpoint_name: str):
31
- checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, Config.CHECKPOINTS[checkpoint_name])
32
  model = torch.jit.load(checkpoint_path)
33
  model.eval()
34
  model.to("cuda")
 
12
  from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
13
 
14
  class Config:
15
+ # os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/main/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")
16
 
17
 
18
+ ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
19
+ CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
20
+ # ASSETS_DIR = os.system(f"wget https://huggingface.co/shimu0215/seg/resolve/")
21
+ # CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "main")
22
  CHECKPOINTS = {
23
  "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
24
  "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
 
28
  class ModelManager:
29
  @staticmethod
30
  def load_model(checkpoint_name: str):
31
+ checkpoint_path = os.path.join(‘shimu0215/seg/’, Config.CHECKPOINTS[checkpoint_name])
32
  model = torch.jit.load(checkpoint_path)
33
  model.eval()
34
  model.to("cuda")