Ashoka74 commited on
Commit
d93491f
ยท
verified ยท
1 Parent(s): 4f4d85a

Update merged_files3.py

Browse files
Files changed (1) hide show
  1. merged_files3.py +5 -4
merged_files3.py CHANGED
@@ -487,8 +487,12 @@ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
487
 
488
  unet.forward = hooked_unet_forward
489
 
 
 
 
 
490
 
491
- sd_offset = sf.load_file(model_path)
492
  sd_origin = unet.state_dict()
493
  keys = sd_origin.keys()
494
  sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
@@ -496,9 +500,6 @@ unet.load_state_dict(sd_merged, strict=True)
496
  del sd_offset, sd_origin, sd_merged, keys
497
 
498
 
499
- # Device and dtype setup
500
- device = torch.device('cuda')
501
- dtype = torch.float16 # Use float16 consistently for all models
502
 
503
  pipe = prepare_pipeline(
504
  base_model="stabilityai/stable-diffusion-xl-base-1.0",
 
487
 
488
  unet.forward = hooked_unet_forward
489
 
490
+ # Device and dtype setup
491
+ device = torch.device('cuda')
492
+ dtype = torch.float16 # Use float16 consistently for all models
493
+
494
 
495
+ sd_offset = sf.load_file(model_path, device=device) # Use device variable
496
  sd_origin = unet.state_dict()
497
  keys = sd_origin.keys()
498
  sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
 
500
  del sd_offset, sd_origin, sd_merged, keys
501
 
502
 
 
 
 
503
 
504
  pipe = prepare_pipeline(
505
  base_model="stabilityai/stable-diffusion-xl-base-1.0",