csgaobb commited on
Commit
5ee368b
·
1 Parent(s): 26e5e7f

fix model loading bug

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -53,20 +53,18 @@ metauas_model = MetaUAS(encoder_name,
53
  def process_image(prompt_img, query_img, options):
54
  # Load the model based on selected options
55
  if 'model-512' in options:
56
- ckt_path = "weights/metauas-512.ckpt"
57
  #model = safely_load_state_dict(metauas_model, ckt_path)
58
 
59
- ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
60
- metauas_model.load_state_dict(torch.load(ckpt_path))
61
 
62
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
63
  img_size = 512
64
  else:
65
- ckt_path = 'weights/metauas-256.ckpt'
66
  #model = safely_load_state_dict(metauas_model, ckt_path)
67
 
68
- ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
69
- metauas_model.load_state_dict(torch.load(ckpt_path))
70
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
71
  img_size = 256
72
 
 
53
  def process_image(prompt_img, query_img, options):
54
  # Load the model based on selected options
55
  if 'model-512' in options:
 
56
  #model = safely_load_state_dict(metauas_model, ckt_path)
57
 
58
+ ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename="metauas-512.ckpt")
59
+ model = metauas_model.load_state_dict(torch.load(ckpt_path))
60
 
61
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
62
  img_size = 512
63
  else:
 
64
  #model = safely_load_state_dict(metauas_model, ckt_path)
65
 
66
+ ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename='metauas-256.ckpt')
67
+ model = metauas_model.load_state_dict(torch.load(ckpt_path))
68
  #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
69
  img_size = 256
70