csgaobb commited on
Commit
861218b
·
1 Parent(s): 7e0623e

fix model loading bug

Browse files
Files changed (2) hide show
  1. __pycache__/metauas.cpython-313.pyc +0 -0
  2. app.py +8 -2
__pycache__/metauas.cpython-313.pyc ADDED
Binary file (13.4 kB). View file
 
app.py CHANGED
@@ -49,17 +49,23 @@ metauas_model = MetaUAS(encoder_name,
49
  fusion_policy
50
  )
51
 
 
 
 
 
 
 
52
  def process_image(prompt_img, query_img, options):
53
  # Load the model based on selected options
54
  if 'model-512' in options:
55
  #ckt_path = "weights/metauas-512.ckpt"
56
  #model = safely_load_state_dict(metauas_model, ckt_path)
57
- model = MetaUAS.load_from_checkpoint("csgaobb/MetaUAS/MetaUAS-512.ckpt")
58
  img_size = 512
59
  else:
60
  #ckt_path = 'weights/metauas-256.ckpt'
61
  #model = safely_load_state_dict(metauas_model, ckt_path)
62
- model = MetaUAS.load_from_checkpoint("csgaobb/MetaUAS/MetaUAS-256.ckpt")
63
  img_size = 256
64
 
65
  model.to(device)
 
49
  fusion_policy
50
  )
51
 
52
+
53
+ model_256 = safely_load_state_dict(metauas_model, "weights/metauas-256.ckpt")
54
+ model_512 = safely_load_state_dict(metauas_model, "weights/metauas-512.ckpt")
55
+ model_256.push_to_hub("csgaobb/MetaUAS-256")
56
+ model_512.push_to_hub("csgaobb/MetaUAS-512")
57
+
58
  def process_image(prompt_img, query_img, options):
59
  # Load the model based on selected options
60
  if 'model-512' in options:
61
  #ckt_path = "weights/metauas-512.ckpt"
62
  #model = safely_load_state_dict(metauas_model, ckt_path)
63
+ model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
64
  img_size = 512
65
  else:
66
  #ckt_path = 'weights/metauas-256.ckpt'
67
  #model = safely_load_state_dict(metauas_model, ckt_path)
68
+ model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
69
  img_size = 256
70
 
71
  model.to(device)