Spaces:
Runtime error
Runtime error
Update merged_files3.py
Browse files- merged_files3.py +5 -4
merged_files3.py
CHANGED
@@ -487,8 +487,12 @@ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
|
|
487 |
|
488 |
unet.forward = hooked_unet_forward
|
489 |
|
|
|
|
|
|
|
|
|
490 |
|
491 |
-
sd_offset = sf.load_file(model_path)
|
492 |
sd_origin = unet.state_dict()
|
493 |
keys = sd_origin.keys()
|
494 |
sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
|
@@ -496,9 +500,6 @@ unet.load_state_dict(sd_merged, strict=True)
|
|
496 |
del sd_offset, sd_origin, sd_merged, keys
|
497 |
|
498 |
|
499 |
-
# Device and dtype setup
|
500 |
-
device = torch.device('cuda')
|
501 |
-
dtype = torch.float16 # Use float16 consistently for all models
|
502 |
|
503 |
pipe = prepare_pipeline(
|
504 |
base_model="stabilityai/stable-diffusion-xl-base-1.0",
|
|
|
487 |
|
488 |
unet.forward = hooked_unet_forward
|
489 |
|
490 |
+
# Device and dtype setup
|
491 |
+
device = torch.device('cuda')
|
492 |
+
dtype = torch.float16 # Use float16 consistently for all models
|
493 |
+
|
494 |
|
495 |
+
sd_offset = sf.load_file(model_path, device=device) # Use device variable
|
496 |
sd_origin = unet.state_dict()
|
497 |
keys = sd_origin.keys()
|
498 |
sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
|
|
|
500 |
del sd_offset, sd_origin, sd_merged, keys
|
501 |
|
502 |
|
|
|
|
|
|
|
503 |
|
504 |
pipe = prepare_pipeline(
|
505 |
base_model="stabilityai/stable-diffusion-xl-base-1.0",
|