|
from typing import AnyStr |
|
import pathlib |
|
from collections import OrderedDict |
|
from packaging import version |
|
|
|
import torch |
|
from diffusers import StableDiffusionPipeline, SchedulerMixin, DiffusionPipeline |
|
from diffusers.utils import is_torch_version, is_xformers_available |
|
|
|
# Short model aliases accepted by `init_diffusion_pipeline`, mapped to the value
# handed to `from_pretrained`: either a local checkpoint directory (the /nfs
# entries) or a model repository id. Insertion order is preserved.
huggingface_model_dict = OrderedDict([
    ("sd14", "/nfs/StableDiffusionModels/CompVis/stable-diffusion-v1-4"),
    ("sd15", "/nfs/StableDiffusionModels/runwayml/stable-diffusion-v1-5"),
    ("sd21b", "stabilityai/stable-diffusion-2-1-base"),
    ("sd21", "stabilityai/stable-diffusion-2-1"),
    ("sdxl", "stabilityai/stable-diffusion-xl-base-1.0"),
])
|
|
|
# Resolution associated with each short model alias; read through `model2res`.
_model2resolution = dict(
    sd14=512,
    sd15=512,
    sd21b=512,
    sd21=768,
    sdxl=1024,
)


def model2res(model_id: str):
    """Return the resolution registered for ``model_id``.

    Args:
        model_id: short model alias such as ``"sd15"`` or ``"sdxl"``.

    Returns:
        The value stored in ``_model2resolution`` for the alias, or ``512``
        when the alias is not registered.
    """
    try:
        return _model2resolution[model_id]
    except KeyError:
        # Unregistered aliases fall back to 512.
        return 512
|
|
|
|
|
def init_diffusion_pipeline(model_id: AnyStr,
                            custom_pipeline: StableDiffusionPipeline,
                            custom_scheduler: SchedulerMixin = None,
                            device: torch.device = "cuda",
                            torch_dtype: torch.dtype = torch.float32,
                            local_files_only: bool = True,
                            force_download: bool = False,
                            resume_download: bool = False,
                            ldm_speed_up: bool = False,
                            enable_xformers: bool = True,
                            gradient_checkpoint: bool = False,
                            lora_path: AnyStr = None,
                            unet_path: AnyStr = None) -> StableDiffusionPipeline:
    """Load a diffusers pipeline and apply optional U-Net customisations.

    The steps run in this order: load the pipeline (optionally with a custom
    scheduler), replace the U-Net, load LoRA layers, wrap the U-Net with
    ``torch.compile``, enable xFormers attention, enable gradient
    checkpointing. Progress and warnings are reported with ``print``.

    Args:
        model_id: a short alias from ``huggingface_model_dict`` (e.g. ``"sd15"``)
            or any value accepted as ``pretrained_model_name_or_path``.
        custom_pipeline: the pipeline *class* to instantiate; its
            ``from_pretrained`` classmethod is called.
        custom_scheduler: optional scheduler *class*; when given, it is loaded
            from the ``scheduler`` subfolder of ``model_id`` and passed to the
            pipeline.
        device: device the pipeline (and a replacement U-Net) is moved to.
        torch_dtype: dtype forwarded to ``from_pretrained``.
        local_files_only: forwarded to ``from_pretrained``; when true, only
            local files are used (no download).
        force_download: forwarded to ``from_pretrained``.
        resume_download: forwarded to ``from_pretrained``.
        ldm_speed_up: wrap the U-Net with ``torch.compile`` (needs torch >= 2.0.0).
        enable_xformers: enable xFormers memory-efficient attention on the
            U-Net when xFormers is installed.
        gradient_checkpoint: enable gradient checkpointing on the U-Net.
        lora_path: path of LoRA attention processors to load into the U-Net;
            skipped with a warning if the path does not exist.
        unet_path: directory of a U-Net checkpoint that replaces the pipeline's
            U-Net; skipped with a warning if the path does not exist.

    Returns:
        The initialised pipeline, already moved to ``device``.
    """
    # Short aliases resolve to a path / repo id; anything else passes through.
    model_id = huggingface_model_dict.get(model_id, model_id)

    # One set of loading options shared by both the "custom scheduler" and the
    # "default scheduler" case; only the `scheduler` entry differs.
    load_kwargs = dict(
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        force_download=force_download,
        resume_download=resume_download,
    )
    if custom_scheduler is not None:
        load_kwargs["scheduler"] = custom_scheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            local_files_only=local_files_only,
        )
    pipeline = custom_pipeline.from_pretrained(model_id, **load_kwargs).to(device)

    # Replace the U-Net with a user-supplied checkpoint.
    if unet_path is not None:
        unet_dir = pathlib.Path(unet_path)
        if unet_dir.exists():
            print(f"=> load u-net from {unet_path}")
            # `from_pretrained` is a classmethod returning a NEW model, so the
            # result has to be assigned back and it must read from `unet_path`
            # (the previous code reloaded `model_id` and dropped the result).
            # Accept both a bare U-Net directory and a pipeline-style directory
            # that keeps the weights in a `unet/` subfolder.
            subfolder = "unet" if (unet_dir / "unet").is_dir() else None
            pipeline.unet = type(pipeline.unet).from_pretrained(
                unet_path,
                subfolder=subfolder,
                torch_dtype=torch_dtype,
                local_files_only=local_files_only,
            ).to(device)
        else:
            print(f"=> warning: u-net checkpoint not found at {unet_path}, keeping the default U-Net")

    # Load LoRA attention processors into the U-Net.
    if lora_path is not None:
        if pathlib.Path(lora_path).exists():
            pipeline.unet.load_attn_procs(lora_path)
            print(f"=> load lora layers into U-Net from {lora_path} ...")
        else:
            print(f"=> warning: lora checkpoint not found at {lora_path}, skipping")

    # Optional torch.compile speed-up (only available from torch 2.0.0 on).
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
            print("=> enable torch.compile on U-Net")
        else:
            print("=> warning: calling torch.compile speed-up failed, since torch version < 2.0.0")

    # Optional xFormers memory-efficient attention.
    if enable_xformers:
        if is_xformers_available():
            import xformers

            if version.parse(xformers.__version__) == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print("=> enable xformers")
            pipeline.unet.enable_xformers_memory_efficient_attention()
        else:
            print("=> warning: calling xformers failed")

    # Optional gradient checkpointing. `is_gradient_checkpointing` only reports
    # whether it is ALREADY active, so gating on it never enabled anything for
    # a freshly loaded model. Ask the model to enable it instead; diffusers
    # raises ValueError when the model class does not support it.
    if gradient_checkpoint:
        try:
            pipeline.unet.enable_gradient_checkpointing()
        except ValueError:
            print("=> waring: gradient checkpointing is not activated for this model.")
        else:
            print("=> enable gradient checkpointing")

    print(f"Diffusion Model: {model_id}")
    print(pipeline.scheduler)
    return pipeline
|
|