csgaobb commited on
Commit
489f385
·
1 Parent(s): 7f9b501

fix model loading bug

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -41,7 +41,7 @@ fusion_policy = 'cat'
41
 
42
  # build model
43
  set_random_seed(random_seed)
44
- metauas_model = MetaUAS(encoder_name,
45
  decoder_name,
46
  encoder_depth,
47
  decoder_depth,
@@ -56,7 +56,7 @@ def process_image(prompt_img, query_img, options):
56
  #model = safely_load_state_dict(metauas_model, ckt_path)
57
 
58
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename="metauas-512.ckpt")
59
- model = metauas_model.load_state_dict(torch.load(ckpt_path), strict=True)
60
 
61
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
62
  img_size = 512
@@ -64,7 +64,7 @@ def process_image(prompt_img, query_img, options):
64
  #model = safely_load_state_dict(metauas_model, ckt_path)
65
 
66
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename='metauas-256.ckpt')
67
- model = metauas_model.load_state_dict(torch.load(ckpt_path), strict=True)
68
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
69
  img_size = 256
70
 
 
41
 
42
  # build model
43
  set_random_seed(random_seed)
44
+ model = MetaUAS(encoder_name,
45
  decoder_name,
46
  encoder_depth,
47
  decoder_depth,
 
56
  #model = safely_load_state_dict(metauas_model, ckt_path)
57
 
58
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename="metauas-512.ckpt")
59
+ model.load_state_dict(torch.load(ckpt_path), strict=True)
60
 
61
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
62
  img_size = 512
 
64
  #model = safely_load_state_dict(metauas_model, ckt_path)
65
 
66
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename='metauas-256.ckpt')
67
+ model.load_state_dict(torch.load(ckpt_path), strict=True)
68
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
69
  img_size = 256
70