import os
import pathlib
from collections import OrderedDict
from typing import Optional, Type, Union

from packaging import version
import torch
from diffusers import StableDiffusionPipeline, SchedulerMixin
from diffusers.utils import is_torch_version, is_xformers_available

huggingface_model_dict = OrderedDict({
    "sd14": "/nfs/StableDiffusionModels/CompVis/stable-diffusion-v1-4",  # resolution: 512
    "sd15": "/nfs/StableDiffusionModels/runwayml/stable-diffusion-v1-5",  # resolution: 512
    "sd21b": "stabilityai/stable-diffusion-2-1-base",  # resolution: 512
    "sd21": "stabilityai/stable-diffusion-2-1",  # resolution: 768
    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",  # resolution: 1024
})

_model2resolution = {
    "sd14": 512,
    "sd15": 512,
    "sd21b": 512,
    "sd21": 768,
    "sdxl": 1024,
}


def model2res(model_id: str) -> int:
    return _model2resolution.get(model_id, 512)
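
# Illustrative lookups (not part of the original module): unknown model ids
# fall back to the 512-pixel default.
# >>> model2res("sdxl")
# 1024
# >>> model2res("a-custom-model-path")
# 512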


def init_diffusion_pipeline(model_id: Union[str, os.PathLike],
                            custom_pipeline: Type[StableDiffusionPipeline],
                            custom_scheduler: Optional[Type[SchedulerMixin]] = None,
                            device: Union[torch.device, str] = "cuda",
                            torch_dtype: torch.dtype = torch.float32,
                            local_files_only: bool = True,
                            force_download: bool = False,
                            resume_download: bool = False,
                            ldm_speed_up: bool = False,
                            enable_xformers: bool = True,
                            gradient_checkpoint: bool = False,
                            lora_path: Optional[str] = None,
                            unet_path: Optional[str] = None) -> StableDiffusionPipeline:
"""
A tool for initial diffusers model.
Args:
model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
custom_pipeline: any StableDiffusionPipeline pipeline
custom_scheduler: any scheduler
device: set device
local_files_only: prohibited download model
force_download: forced download model
resume_download: re-download model
ldm_speed_up: use the `torch.compile` api to speed up unet
enable_xformers: enable memory efficient attention from [xFormers]
gradient_checkpoint: activates gradient checkpointing for the current model
lora_path: load LoRA checkpoint
unet_path: load unet checkpoint
Returns:
diffusers.StableDiffusionPipeline
"""
    # resolve a shorthand key to its pretrained model path, or use model_id as-is
    model_id = huggingface_model_dict.get(model_id, model_id)

    # build the diffusion pipeline, optionally with a custom scheduler
    if custom_scheduler is not None:
        pipeline = custom_pipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            force_download=force_download,
            resume_download=resume_download,
            scheduler=custom_scheduler.from_pretrained(model_id,
                                                       subfolder="scheduler",
                                                       local_files_only=local_files_only)
        ).to(device)
    else:
        pipeline = custom_pipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            local_files_only=local_files_only,
            force_download=force_download,
            resume_download=resume_download,
        ).to(device)

    # load a fine-tuned U-Net checkpoint if one is provided; `from_pretrained`
    # is a classmethod that returns a new model, so assign the result back
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load U-Net from {unet_path}")
        pipeline.unet = pipeline.unet.__class__.from_pretrained(unet_path, torch_dtype=torch_dtype).to(device)

    # load LoRA attention layers if a checkpoint is provided
    if lora_path is not None and pathlib.Path(lora_path).exists():
        pipeline.unet.load_attn_procs(lora_path)
        print(f"=> load LoRA layers into U-Net from {lora_path} ...")

    # torch.compile the U-Net (requires PyTorch >= 2.0)
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
            print("=> enable torch.compile on U-Net")
        else:
            print("=> warning: torch.compile speed-up skipped, since torch version < 2.0.0")

    # xFormers memory-efficient attention
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print("=> enable xformers")
            pipeline.unet.enable_xformers_memory_efficient_attention()
        else:
            print("=> warning: xformers is not available, memory-efficient attention is disabled")

    # gradient checkpointing (trades compute for memory during training)
    if gradient_checkpoint:
        if not pipeline.unet.is_gradient_checkpointing:
            print("=> enable gradient checkpointing")
            pipeline.unet.enable_gradient_checkpointing()
        else:
            print("=> gradient checkpointing is already activated for this model.")
print(f"Diffusion Model: {model_id}")
print(pipeline.scheduler)
return pipeline
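

# A minimal usage sketch (illustrative, not part of the original module); it
# assumes a CUDA device and uses the "sd21b" key, which resolves to a Hugging
# Face Hub model id, so downloading must be allowed on the first run.
if __name__ == "__main__":
    pipe = init_diffusion_pipeline(
        "sd21b",
        custom_pipeline=StableDiffusionPipeline,
        torch_dtype=torch.float16,
        local_files_only=False,  # allow the first-run download from the Hub
    )
    image = pipe("a photo of an astronaut riding a horse").images[0]  # illustrative prompt
    image.save("sample.png")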