csgaobb committed
Commit f2fb485 · 1 Parent(s): 861218b

fix model loading bug

Files changed (2):
  1. app.py      +11 -10
  2. metauas.py  +2 -2
app.py CHANGED

@@ -49,23 +49,24 @@ metauas_model = MetaUAS(encoder_name,
                         fusion_policy
                         )
 
-
-model_256 = safely_load_state_dict(metauas_model, "weights/metauas-256.ckpt")
-model_512 = safely_load_state_dict(metauas_model, "weights/metauas-512.ckpt")
-model_256.push_to_hub("csgaobb/MetaUAS-256")
-model_512.push_to_hub("csgaobb/MetaUAS-512")
-
 def process_image(prompt_img, query_img, options):
     # Load the model based on selected options
     if 'model-512' in options:
-        #ckt_path = "weights/metauas-512.ckpt"
+        ckt_path = "weights/metauas-512.ckpt"
         #model = safely_load_state_dict(metauas_model, ckt_path)
-        model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
+
+        ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
+        metauas_model.load_state_dict(torch.load(ckpt_path))
+
+        #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
         img_size = 512
     else:
-        #ckt_path = 'weights/metauas-256.ckpt'
+        ckt_path = 'weights/metauas-256.ckpt'
         #model = safely_load_state_dict(metauas_model, ckt_path)
-        model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
+
+        ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
+        metauas_model.load_state_dict(torch.load(ckpt_path))
+        #model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
         img_size = 256
 
     model.to(device)
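
For orientation, here is the new loading path pulled out of the diff as a standalone sketch. It reuses the metauas_model and device objects already defined in app.py, assumes the checkpoint lives in the csgaobb/MetaUAS repo under weights/metauas-512.ckpt as in the diff, and assumes the .ckpt file may be either a plain state dict or a full Lightning checkpoint (which nests the weights under a "state_dict" key), so it handles both cases.

    import torch
    from huggingface_hub import hf_hub_download

    # Fetch the checkpoint from the Hub (cached locally after the first call).
    ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS",
                                filename="weights/metauas-512.ckpt")

    # Load onto CPU first, then move the module to the target device.
    state = torch.load(ckpt_path, map_location="cpu")

    # A Lightning .ckpt nests the weights under "state_dict"; a raw state-dict
    # file does not. The exact on-disk format is an assumption here.
    state_dict = state.get("state_dict", state)

    # metauas_model and device are the objects constructed earlier in app.py.
    metauas_model.load_state_dict(state_dict)
    metauas_model.to(device).eval()

Downloading with hf_hub_download and loading the state dict directly means the app no longer depends on PyTorchModelHubMixin, which the second file change removes from the class.
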
metauas.py CHANGED

@@ -132,7 +132,7 @@ class AlignmentLayer(nn.Module):
         return aligned_features
 
 
-class MetaUAS(pl.LightningModule, PyTorchModelHubMixin):
+class MetaUAS(pl.LightningModule):
     def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
         super().__init__()
 
@@ -267,7 +267,7 @@ class MetaUAS(pl.LightningModule, PyTorchModelHubMixin):
             stride=1,
             padding=0,
         )
-
+
     def forward(self, batch):
         query_input = self.preprocess(batch["query_image"])
         prompt_input = self.preprocess(batch["prompt_image"])
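
The metauas.py change removes PyTorchModelHubMixin from the class bases; the mixin is what provided the from_pretrained and push_to_hub helpers used by the lines deleted from app.py. A minimal sketch of what the mixin added, shown only for contrast (the class name here is hypothetical, and the repo ids are the ones from the deleted code, not calls that still exist after this commit):

    import pytorch_lightning as pl
    from huggingface_hub import PyTorchModelHubMixin

    # Before this commit: the mixin gave the model Hub-aware (de)serialization.
    class MetaUASWithMixin(pl.LightningModule, PyTorchModelHubMixin):
        ...

    # model.push_to_hub("csgaobb/MetaUAS-512")                        # upload weights + config
    # model = MetaUASWithMixin.from_pretrained("csgaobb/MetaUAS-512") # rebuild and load

    # After the commit the mixin is gone, so checkpoints are fetched with
    # hf_hub_download and loaded with load_state_dict instead (see app.py above).
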