Spaces:
Running
on
Zero
Running
on
Zero
fix model loading bug
Browse files- app.py +11 -10
- metauas.py +2 -2
app.py
CHANGED
@@ -49,23 +49,24 @@ metauas_model = MetaUAS(encoder_name,
|
|
49 |
fusion_policy
|
50 |
)
|
51 |
|
52 |
-
|
53 |
-
model_256 = safely_load_state_dict(metauas_model, "weights/metauas-256.ckpt")
|
54 |
-
model_512 = safely_load_state_dict(metauas_model, "weights/metauas-512.ckpt")
|
55 |
-
model_256.push_to_hub("csgaobb/MetaUAS-256")
|
56 |
-
model_512.push_to_hub("csgaobb/MetaUAS-512")
|
57 |
-
|
58 |
def process_image(prompt_img, query_img, options):
|
59 |
# Load the model based on selected options
|
60 |
if 'model-512' in options:
|
61 |
-
|
62 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
img_size = 512
|
65 |
else:
|
66 |
-
|
67 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
68 |
-
|
|
|
|
|
|
|
69 |
img_size = 256
|
70 |
|
71 |
model.to(device)
|
|
|
49 |
fusion_policy
|
50 |
)
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
def process_image(prompt_img, query_img, options):
|
53 |
# Load the model based on selected options
|
54 |
if 'model-512' in options:
|
55 |
+
ckt_path = "weights/metauas-512.ckpt"
|
56 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
57 |
+
|
58 |
+
ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
|
59 |
+
metauas_model.load_state_dict(torch.load(ckpt_path))
|
60 |
+
|
61 |
+
#model = MetaUAS.from_pretrained("csgaobb/MetaUAS-512")
|
62 |
img_size = 512
|
63 |
else:
|
64 |
+
ckt_path = 'weights/metauas-256.ckpt'
|
65 |
#model = safely_load_state_dict(metauas_model, ckt_path)
|
66 |
+
|
67 |
+
ckpt_path = hf_hub_download(repo_id="csgaobb/MetaUAS", filename=ckt_path)
|
68 |
+
metauas_model.load_state_dict(torch.load(ckpt_path))
|
69 |
+
#model = MetaUAS.from_pretrained("csgaobb/MetaUAS-256")
|
70 |
img_size = 256
|
71 |
|
72 |
model.to(device)
|
metauas.py
CHANGED
@@ -132,7 +132,7 @@ class AlignmentLayer(nn.Module):
|
|
132 |
return aligned_features
|
133 |
|
134 |
|
135 |
-
class MetaUAS(pl.LightningModule
|
136 |
def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
|
137 |
super().__init__()
|
138 |
|
@@ -267,7 +267,7 @@ class MetaUAS(pl.LightningModule, PyTorchModelHubMixin):
|
|
267 |
stride=1,
|
268 |
padding=0,
|
269 |
)
|
270 |
-
|
271 |
def forward(self, batch):
|
272 |
query_input = self.preprocess(batch["query_image"])
|
273 |
prompt_input = self.preprocess(batch["prompt_image"])
|
|
|
132 |
return aligned_features
|
133 |
|
134 |
|
135 |
+
class MetaUAS(pl.LightningModule):
|
136 |
def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
|
137 |
super().__init__()
|
138 |
|
|
|
267 |
stride=1,
|
268 |
padding=0,
|
269 |
)
|
270 |
+
|
271 |
def forward(self, batch):
|
272 |
query_input = self.preprocess(batch["query_image"])
|
273 |
prompt_input = self.preprocess(batch["prompt_image"])
|