csgaobb commited on
Commit
7f9b501
·
1 Parent(s): 5ee368b

fix model loading bug

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -56,7 +56,7 @@ def process_image(prompt_img, query_img, options):
56
  #model = safely_load_state_dict(metauas_model, ckt_path)
57
 
58
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename="metauas-512.ckpt")
59
- model = metauas_model.load_state_dict(torch.load(ckpt_path))
60
 
61
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
62
  img_size = 512
@@ -64,7 +64,7 @@ def process_image(prompt_img, query_img, options):
64
  #model = safely_load_state_dict(metauas_model, ckt_path)
65
 
66
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename='metauas-256.ckpt')
67
- model = metauas_model.load_state_dict(torch.load(ckpt_path))
68
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
69
  img_size = 256
70
 
 
56
  #model = safely_load_state_dict(metauas_model, ckt_path)
57
 
58
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename="metauas-512.ckpt")
59
+ model = metauas_model.load_state_dict(torch.load(ckpt_path), strict=True)
60
 
61
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
62
  img_size = 512
 
64
  #model = safely_load_state_dict(metauas_model, ckt_path)
65
 
66
  ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename='metauas-256.ckpt')
67
+ model = metauas_model.load_state_dict(torch.load(ckpt_path), strict=True)
68
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
69
  img_size = 256
70