Spaces:
Runtime error
Runtime error
File size: 6,805 Bytes
5fb352c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import torch
import safetensors.torch
import os
import collections
from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy
# Directory scanned for standalone VAE files: <models_path>/VAE.
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))
# State-dict keys that load_vae_dict() drops, in addition to keys starting with "loss".
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
# Maps VAE file name (basename) -> file path; filled by refresh_vae_list() and load_vae().
vae_dict = {}
# Deep copy of the model's own VAE state dict, taken by store_base_vae().
base_vae = None
# Value of vae_file from the most recent load_vae() call (None when no file VAE is loaded).
loaded_vae_file = None
# sd_checkpoint_info of the model that base_vae was copied from.
checkpoint_info = None
# Cache of VAE state dicts keyed by file path, in insertion order; load_vae() evicts from the front.
checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model):
    """Return the stored base VAE state dict if it belongs to ``model``.

    Args:
        model: Object exposing ``sd_checkpoint_info``; may be ``None``.

    Returns:
        The module-level ``base_vae`` when one is stored, ``model`` is truthy
        and the stored ``checkpoint_info`` equals ``model.sd_checkpoint_info``;
        otherwise ``None``.
    """
    # Test ``model`` before reading its attributes: checking it last meant a
    # ``None`` model raised AttributeError instead of yielding ``None``.
    if base_vae is not None and model and checkpoint_info == model.sd_checkpoint_info:
        return base_vae
    return None
def store_base_vae(model):
    """Snapshot the model's current VAE weights, once per checkpoint.

    Does nothing when the stored ``checkpoint_info`` already equals
    ``model.sd_checkpoint_info``. Otherwise deep-copies
    ``model.first_stage_model.state_dict()`` into ``base_vae`` and records
    the checkpoint info.

    Raises:
        AssertionError: if a snapshot is needed while ``loaded_vae_file`` is
            set, i.e. the weights in the model are not the base VAE.
    """
    global base_vae, checkpoint_info

    current_info = model.sd_checkpoint_info
    if checkpoint_info == current_info:
        return

    assert not loaded_vae_file, "Trying to store non-base VAE!"
    snapshot = deepcopy(model.first_stage_model.state_dict())
    base_vae = snapshot
    checkpoint_info = current_info
def delete_base_vae():
    """Forget the stored base VAE and the checkpoint info it was taken from."""
    global base_vae, checkpoint_info
    base_vae = checkpoint_info = None
def restore_base_vae(model):
    """Put the stored base VAE weights back into ``model`` and drop the snapshot.

    The weights are re-applied only when a snapshot exists and it was taken
    from the same checkpoint as ``model``; in that case ``loaded_vae_file``
    is reset to ``None``. The snapshot is deleted in every case.
    """
    global loaded_vae_file

    has_snapshot = base_vae is not None
    if has_snapshot and checkpoint_info == model.sd_checkpoint_info:
        print("Restoring base VAE")
        _load_vae_dict(model, base_vae)
        loaded_vae_file = None

    # Dropped whether or not it was applied above.
    delete_base_vae()
def get_filename(filepath):
    """Return the last path component of ``filepath`` (its file name)."""
    _directory, file_name = os.path.split(filepath)
    return file_name
def refresh_vae_list():
    """Rebuild ``vae_dict`` by globbing the known model and VAE directories.

    Checkpoint directories are searched only for ``*.vae.<ext>`` files, VAE
    directories for any ``*.<ext>`` file (ext in ckpt / pt / safetensors),
    recursively. The optional ``--ckpt-dir`` / ``--vae-dir`` locations are
    included when set and existing. Entries are keyed by file name, so a later
    match with the same name replaces an earlier one.
    """
    vae_dict.clear()

    beside_checkpoint = ('**/*.vae.ckpt', '**/*.vae.pt', '**/*.vae.safetensors')
    standalone = ('**/*.ckpt', '**/*.pt', '**/*.safetensors')

    patterns = [os.path.join(sd_models.model_path, pattern) for pattern in beside_checkpoint]
    patterns.extend(os.path.join(vae_path, pattern) for pattern in standalone)

    extra_ckpt_dir = shared.cmd_opts.ckpt_dir
    if extra_ckpt_dir is not None and os.path.isdir(extra_ckpt_dir):
        patterns.extend(os.path.join(extra_ckpt_dir, pattern) for pattern in beside_checkpoint)

    extra_vae_dir = shared.cmd_opts.vae_dir
    if extra_vae_dir is not None and os.path.isdir(extra_vae_dir):
        patterns.extend(os.path.join(extra_vae_dir, pattern) for pattern in standalone)

    # Patterns are walked in the order built above; within a pattern, in glob order.
    for pattern in patterns:
        for found in glob.iglob(pattern, recursive=True):
            vae_dict[get_filename(found)] = found
def find_vae_near_checkpoint(checkpoint_file):
    """Look for a VAE file sitting next to a checkpoint.

    The checkpoint's extension is replaced by ``.vae.pt``, ``.vae.ckpt`` and
    ``.vae.safetensors`` in that order; the first path that is an existing
    file is returned, or ``None`` when none exists.
    """
    stem, _extension = os.path.splitext(checkpoint_file)
    suffixes = (".vae.pt", ".vae.ckpt", ".vae.safetensors")
    existing = (stem + suffix for suffix in suffixes if os.path.isfile(stem + suffix))
    return next(existing, None)
def resolve_vae(checkpoint_file):
    """Decide which VAE file should be used for ``checkpoint_file``.

    Precedence:
      1. the ``vae_path`` command-line option;
      2. a VAE found next to the checkpoint, when the ``sd_vae`` setting is
         automatic or ``sd_vae_as_default`` is enabled;
      3. nothing, when the setting is the literal ``"None"``;
      4. the ``vae_dict`` entry named by the setting.

    Returns:
        A ``(path, source_description)`` tuple, or ``(None, None)`` when no
        VAE applies.
    """
    from_cmdline = shared.cmd_opts.vae_path
    if from_cmdline is not None:
        return from_cmdline, 'from commandline argument'

    selected = shared.opts.sd_vae
    is_automatic = selected in {"Automatic", "auto"}  # "auto" for people with old config

    sidecar = find_vae_near_checkpoint(checkpoint_file)
    prefer_sidecar = shared.opts.sd_vae_as_default or is_automatic
    if sidecar is not None and prefer_sidecar:
        return sidecar, 'found near the checkpoint'

    if selected == "None":
        return None, None

    configured = vae_dict.get(selected, None)
    if configured is not None:
        return configured, 'specified in settings'

    # Only a concrete, named selection that is missing deserves a warning.
    if not is_automatic:
        print(f"Couldn't find VAE named {selected}; using None instead")

    return None, None
def load_vae_dict(filename, map_location):
    """Read a VAE state dict from ``filename`` and strip non-weight entries.

    The file is read through ``sd_models.read_state_dict``. Keys starting with
    ``"loss"`` and keys listed in ``vae_ignore_keys`` are left out of the
    returned dict.
    """
    raw_state = sd_models.read_state_dict(filename, map_location=map_location)

    filtered = {}
    for key, value in raw_state.items():
        if key.startswith("loss") or key in vae_ignore_keys:
            continue
        filtered[key] = value
    return filtered
def load_vae(model, vae_file=None, vae_source="from unknown source"):
    """Load VAE weights from ``vae_file`` into ``model``, or restore the base VAE.

    Args:
        model: Model whose ``first_stage_model`` receives the weights.
        vae_file: Path of the VAE to load. When falsy and a VAE file is
            currently loaded, the stored base VAE is restored instead.
        vae_source: Human-readable origin of ``vae_file``, used in log output.

    Raises:
        AssertionError: if ``vae_file`` is not cached and is not an existing
            file (also propagated from ``store_base_vae``).

    Side effects:
        Updates ``loaded_vae_file``, the ``checkpoints_loaded`` cache and may
        add an entry to ``vae_dict``.
    """
    global loaded_vae_file

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

    if vae_file:
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
            # Refresh recency on a hit; without this the eviction below is
            # FIFO and can throw out the VAE that is used most often.
            checkpoints_loaded.move_to_end(vae_file)
        else:
            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
            print(f"Loading VAE weights {vae_source}: {vae_file}")
            store_base_vae(model)

            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
                # cache newly loaded vae
                checkpoints_loaded[vae_file] = vae_dict_1.copy()

        # clean up cache if limit is reached
        if cache_enabled:
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:  # we need to count the current model
                checkpoints_loaded.popitem(last=False)  # LRU

        # If vae used is not in dict, update it
        # It will be removed on refresh though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file

    elif loaded_vae_file:
        restore_base_vae(model)

    loaded_vae_file = vae_file
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    """Load ``vae_dict_1`` into ``model.first_stage_model``, then call ``.to(devices.dtype_vae)`` on it."""
    first_stage = model.first_stage_model
    first_stage.load_state_dict(vae_dict_1)
    first_stage.to(devices.dtype_vae)
def clear_loaded_vae():
    """Reset ``loaded_vae_file`` to ``None`` without touching any model weights."""
    global loaded_vae_file
    loaded_vae_file = None
# Sentinel default for reload_vae_weights(): lets it tell "vae_file not passed" apart from an explicit None.
unspecified = object()
def reload_vae_weights(sd_model=None, vae_file=unspecified):
    """Work out which VAE should be active and load it into ``sd_model``.

    Args:
        sd_model: Model to update; ``shared.sd_model`` is used when falsy.
        vae_file: VAE path to load. When omitted, the path is chosen by
            ``resolve_vae`` from the model's checkpoint file. An explicit
            value (including ``None``) is passed to ``load_vae`` as is.

    Returns:
        ``sd_model`` after the reload, or ``None`` when ``vae_file`` already
        equals ``loaded_vae_file`` and nothing was done.
    """
    # Function-scope import, presumably to avoid a circular import - TODO confirm.
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename

    # Sentinels are compared by identity, not equality.
    if vae_file is unspecified:
        vae_file, vae_source = resolve_vae(checkpoint_file)
    else:
        vae_source = "from function argument"

    if loaded_vae_file == vae_file:
        return

    # Move the model off the device before its weights are swapped.
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    # The hijack is undone around the weight swap and re-applied afterwards.
    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file, vae_source)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print("VAE weights loaded.")
    return sd_model
|