Spaces:
Runtime error
Runtime error
user
5fb352c
import torch | |
import safetensors.torch | |
import os | |
import collections | |
from collections import namedtuple | |
from modules import paths, shared, devices, script_callbacks, sd_models | |
import glob | |
from copy import deepcopy | |
# Default directory that is scanned for standalone VAE files.
vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE"))

# State-dict keys that are dropped from every VAE checkpoint when it is read.
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}

# Maps a VAE file's base name to its full path; rebuilt by refresh_vae_list().
vae_dict = {}

# Snapshot of the VAE weights that came with the checkpoint, together with
# the checkpoint info identifying which checkpoint the snapshot belongs to.
base_vae = None
checkpoint_info = None

# Path of the VAE file currently loaded over the base VAE, or None.
loaded_vae_file = None

# Cache of already-read VAE state dicts keyed by VAE file path.
checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model):
    """Return the stored base VAE state dict for ``model``, if there is one.

    Args:
        model: The loaded model. Its ``sd_checkpoint_info`` is compared with
            the checkpoint info recorded when the snapshot was stored.

    Returns:
        The stored state dict when a snapshot exists and was taken from the
        checkpoint ``model`` reports; otherwise ``None`` (including when
        ``model`` is ``None`` or otherwise falsy).
    """
    # Test the model first: previously it was only checked after
    # model.sd_checkpoint_info had been read, so a None model raised
    # AttributeError instead of yielding None.
    if not model:
        return None
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        return base_vae
    return None
def store_base_vae(model):
    """Snapshot the model's current VAE weights as the base VAE.

    Nothing happens when a snapshot for the same checkpoint info is already
    recorded. Otherwise a deep copy of ``model.first_stage_model``'s state
    dict is stored together with ``model.sd_checkpoint_info``.

    Raises:
        AssertionError: If a replacement VAE file is currently recorded as
            loaded, because the weights would then not be the base VAE.
    """
    global base_vae, checkpoint_info

    owner = model.sd_checkpoint_info
    if checkpoint_info != owner:
        assert not loaded_vae_file, "Trying to store non-base VAE!"
        weights = model.first_stage_model.state_dict()
        base_vae = deepcopy(weights)
        checkpoint_info = owner
def delete_base_vae():
    """Forget the stored base VAE snapshot and the checkpoint info it belongs to."""
    global base_vae, checkpoint_info
    base_vae = checkpoint_info = None
def restore_base_vae(model):
    """Load the stored base VAE back into ``model`` and discard the snapshot.

    The weights are only restored when a snapshot exists and was taken from
    the checkpoint ``model`` reports; in that case the record of the loaded
    VAE file is cleared as well. The snapshot is discarded either way.
    """
    global loaded_vae_file

    snapshot_matches = base_vae is not None and checkpoint_info == model.sd_checkpoint_info
    if snapshot_matches:
        print("Restoring base VAE")
        _load_vae_dict(model, base_vae)
        loaded_vae_file = None

    # The snapshot is dropped whether or not it was applied.
    delete_base_vae()
def get_filename(filepath):
    """Return the final path component of ``filepath`` (its file name)."""
    _directory, filename = os.path.split(filepath)
    return filename
def refresh_vae_list():
    """Rebuild ``vae_dict`` by scanning the known directories for VAE files.

    The checkpoint directories are searched for ``*.vae.<ext>`` files and the
    VAE directories for any ``*.<ext>`` file, recursively, where ``<ext>`` is
    ``ckpt``, ``pt`` or ``safetensors``. The directories given on the command
    line are included only when set and existing. Entries are keyed by file
    name, so a file found later replaces an earlier one with the same name.
    """
    vae_dict.clear()

    beside_checkpoint = ('**/*.vae.ckpt', '**/*.vae.pt', '**/*.vae.safetensors')
    standalone = ('**/*.ckpt', '**/*.pt', '**/*.safetensors')

    # (directory, patterns) pairs in search order.
    search_plan = [
        (sd_models.model_path, beside_checkpoint),
        (vae_path, standalone),
    ]

    ckpt_dir = shared.cmd_opts.ckpt_dir
    if ckpt_dir is not None and os.path.isdir(ckpt_dir):
        search_plan.append((ckpt_dir, beside_checkpoint))

    vae_dir = shared.cmd_opts.vae_dir
    if vae_dir is not None and os.path.isdir(vae_dir):
        search_plan.append((vae_dir, standalone))

    patterns = [
        os.path.join(directory, pattern)
        for directory, group in search_plan
        for pattern in group
    ]

    found = []
    for pattern in patterns:
        found.extend(glob.iglob(pattern, recursive=True))

    for filepath in found:
        vae_dict[get_filename(filepath)] = filepath
def find_vae_near_checkpoint(checkpoint_file):
    """Look for a VAE file stored next to a checkpoint.

    The checkpoint's extension is replaced by ``.vae.pt``, ``.vae.ckpt`` and
    ``.vae.safetensors`` in that order.

    Returns:
        The first of those paths that is an existing file, or ``None``.
    """
    stem, _extension = os.path.splitext(checkpoint_file)
    candidates = [stem + suffix for suffix in (".vae.pt", ".vae.ckpt", ".vae.safetensors")]
    return next((path for path in candidates if os.path.isfile(path)), None)
def resolve_vae(checkpoint_file):
    """Decide which VAE file should be used with a checkpoint.

    Precedence:
      1. the VAE path given on the command line;
      2. a VAE file next to the checkpoint, when the VAE setting is automatic
         or ``sd_vae_as_default`` is enabled;
      3. nothing, when the VAE setting is the string ``"None"``;
      4. the file registered in ``vae_dict`` under the VAE setting's name.

    Returns:
        A ``(path, source_description)`` tuple, or ``(None, None)`` when no
        VAE should be loaded.
    """
    forced_path = shared.cmd_opts.vae_path
    if forced_path is not None:
        return forced_path, 'from commandline argument'

    selected = shared.opts.sd_vae
    # "auto" for people with old config
    is_automatic = selected in {"Automatic", "auto"}

    sibling_vae = find_vae_near_checkpoint(checkpoint_file)
    if sibling_vae is not None and (shared.opts.sd_vae_as_default or is_automatic):
        return sibling_vae, 'found near the checkpoint'

    if selected == "None":
        return None, None

    configured_path = vae_dict.get(selected, None)
    if configured_path is not None:
        return configured_path, 'specified in settings'

    # Only a concrete, named selection that cannot be found is worth reporting.
    if not is_automatic:
        print(f"Couldn't find VAE named {selected}; using None instead")

    return None, None
def load_vae_dict(filename, map_location):
    """Read a VAE checkpoint and return its state dict without unwanted keys.

    Keys that start with ``"loss"`` and keys listed in ``vae_ignore_keys``
    are left out of the returned dict.
    """
    raw_state = sd_models.read_state_dict(filename, map_location=map_location)

    filtered = {}
    for key, value in raw_state.items():
        if key.startswith("loss") or key in vae_ignore_keys:
            continue
        filtered[key] = value
    return filtered
def load_vae(model, vae_file=None, vae_source="from unknown source"):
    """Load the weights from ``vae_file`` into ``model``'s VAE.

    Args:
        model: Model whose ``first_stage_model`` receives the weights.
        vae_file: Path of the VAE file to load. When falsy, the base VAE is
            restored instead if a replacement VAE is currently loaded.
        vae_source: Human-readable description of where ``vae_file`` came
            from; only used in printed messages and the assertion text.

    Raises:
        AssertionError: If ``vae_file`` is given, is not served from the
            cache, and does not exist on disk.

    Side effects:
        Records ``vae_file`` in ``loaded_vae_file``, may add or evict entries
        in ``checkpoints_loaded``, and registers ``vae_file`` in ``vae_dict``
        under its file name when it is not listed there yet.
    """
    global loaded_vae_file

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

    if vae_file:
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
            # Refresh the entry's position so the eviction below really is
            # least-recently-used; without this a cache hit left the entry
            # at its insertion position and the VAE that was just loaded
            # could be the next one evicted.
            checkpoints_loaded.move_to_end(vae_file)
        else:
            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
            print(f"Loading VAE weights {vae_source}: {vae_file}")
            store_base_vae(model)

            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
                # cache newly loaded vae
                checkpoints_loaded[vae_file] = vae_dict_1.copy()

        # clean up cache if limit is reached
        if cache_enabled:
            # "+ 1" because we need to count the current model
            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1:
                checkpoints_loaded.popitem(last=False)  # LRU

        # If vae used is not in dict, update it
        # It will be removed on refresh though
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file

    elif loaded_vae_file:
        restore_base_vae(model)

    loaded_vae_file = vae_file
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
    """Load ``vae_dict_1`` into the model's VAE, then call ``.to(devices.dtype_vae)`` on it."""
    vae = model.first_stage_model
    vae.load_state_dict(vae_dict_1)
    vae.to(devices.dtype_vae)
def clear_loaded_vae():
    """Forget which VAE file is recorded as loaded; model weights are not touched."""
    global loaded_vae_file
    loaded_vae_file = None
# Sentinel default for reload_vae_weights(): lets callers pass vae_file=None
# explicitly (meaning "no VAE") without it being mistaken for "not given".
unspecified = object()


def reload_vae_weights(sd_model=None, vae_file=unspecified):
    """Swap the VAE weights of a loaded model.

    Args:
        sd_model: Model to update; ``shared.sd_model`` is used when falsy.
        vae_file: Path of the VAE to load, or ``None`` for no replacement
            VAE. When left at the ``unspecified`` sentinel, the file is
            chosen by ``resolve_vae()`` for the model's checkpoint.

    Returns:
        ``None`` when the requested VAE is already the loaded one (nothing
        is done); otherwise ``sd_model`` after its VAE has been reloaded.
    """
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
        sd_model = shared.sd_model

    # No local named checkpoint_info here: that would shadow the module-level
    # global of the same name.
    checkpoint_file = sd_model.sd_checkpoint_info.filename

    # A sentinel is compared by identity, not equality.
    if vae_file is unspecified:
        vae_file, vae_source = resolve_vae(checkpoint_file)
    else:
        vae_source = "from function argument"

    if loaded_vae_file == vae_file:
        return None

    offloading = shared.cmd_opts.lowvram or shared.cmd_opts.medvram

    if offloading:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    # The hijack is undone before the weights are swapped and re-applied
    # afterwards.
    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file, vae_source)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)

    if not offloading:
        sd_model.to(devices.device)

    print("VAE weights loaded.")
    return sd_model