WindVChen commited on
Commit
0560442
Β·
1 Parent(s): 7f7fc9d

Update inference_for_arbitrary_resolution_image.py

Browse files
inference_for_arbitrary_resolution_image.py CHANGED
@@ -321,7 +321,7 @@ def main_process(opt, composite_image=None, mask=None):
321
 
322
  model = build_model(opt).to(opt.device)
323
 
324
- load_dict = torch.load(opt.pretrained)['model']
325
  for k in load_dict.keys():
326
  if k not in model.state_dict().keys():
327
  print(f"Skip {k}")
 
321
 
322
  model = build_model(opt).to(opt.device)
323
 
324
+ load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
325
  for k in load_dict.keys():
326
  if k not in model.state_dict().keys():
327
  print(f"Skip {k}")