import gc
import warnings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Callable

import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LineartAnimeDetector,
    LineartDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
    PidiNetDetector,
)
from controlnet_aux.util import HWC3

from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor


# Interface adapted from
# https://github.com/huggingface/controlnet_aux/blob/master/src/controlnet_aux/canny/__init__.py
class TexnetPreprocessor:
    """Pass-through preprocessor: resize the input and return a copy of it.

    The call signature mirrors the Canny detector it was adapted from, but no
    detection is performed here.
    """

    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, image_resolution=512, output_type=None, **kwargs):
        """Return ``input_image`` resized via ``resize_image``.

        Args:
            input_image: Image to process; an ``np.ndarray`` or anything
                ``np.array(..., dtype=np.uint8)`` accepts (e.g. a PIL image).
            low_threshold: Accepted for signature compatibility; unused.
            high_threshold: Accepted for signature compatibility; unused.
            image_resolution: Resolution forwarded to ``resize_image``.
            output_type: ``"pil"`` to get a ``PIL.Image.Image``; any other
                value returns the array. Defaults to ``"np"`` for ndarray
                input and ``"pil"`` otherwise.
            **kwargs: Only the deprecated ``img`` alias for ``input_image`` is
                read; everything else is ignored.

        Raises:
            ValueError: If no input image was supplied.
        """
        if "img" in kwargs:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(
                "img is deprecated, please use `input_image=...` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            input_image = kwargs.pop("img")

        if input_image is None:
            raise ValueError("input_image must be defined.")

        # The default output type follows the input type.
        if isinstance(input_image, np.ndarray):
            output_type = output_type or "np"
        else:
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, image_resolution)

        output_image = input_image.copy()
        if output_type == "pil":
            output_image = PIL.Image.fromarray(output_image)
        return output_image


class Preprocessor:
    """Lazily loads one named annotator at a time and applies it to images."""

    MODEL_ID = "lllyasviel/Annotators"

    # Annotators built with ``from_pretrained(MODEL_ID)``.
    _PRETRAINED = {
        "HED": HEDdetector,
        "Midas": MidasDetector,
        "MLSD": MLSDdetector,
        "Openpose": OpenposeDetector,
        "PidiNet": PidiNetDetector,
        "NormalBae": NormalBaeDetector,
        "Lineart": LineartDetector,
        "LineartAnime": LineartAnimeDetector,
    }

    # Annotators built by calling the class with no arguments.
    _DIRECT = {
        "Canny": CannyDetector,
        "ContentShuffle": ContentShuffleDetector,
        "DPT": DepthEstimator,
        "UPerNet": ImageSegmentor,
        "texnet": TexnetPreprocessor,
    }

    def __init__(self) -> None:
        self.model: Callable = None  # type: ignore
        self.name = ""

    def load(self, name: str) -> None:
        """Make ``name`` the active annotator, replacing any previous one.

        Does nothing when ``name`` is already the active annotator.

        Raises:
            ValueError: If ``name`` is not a known annotator; the currently
                loaded model is left untouched in that case.
        """
        if name == self.name:
            return

        if name in self._PRETRAINED:
            self.model = self._PRETRAINED[name].from_pretrained(self.MODEL_ID)
        elif name in self._DIRECT:
            self.model = self._DIRECT[name]()
        else:
            raise ValueError(f"Unknown preprocessor name: {name!r}")

        # Collect the replaced model first so that any memory it held can be
        # returned when the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    @staticmethod
    def _to_resized_array(image, resolution):
        """Convert ``image`` to an array, pass it through HWC3, and resize it."""
        return resize_image(HWC3(np.array(image)), resolution=resolution)

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:  # noqa: ANN003
        """Run the active annotator on ``image``.

        ``detect_resolution`` (Canny, Midas) and ``image_resolution`` (Midas)
        are consumed here; all remaining keyword arguments are forwarded to
        the model.
        """
        if self.name == "Canny":
            # NOTE(review): nesting reconstructed from the upstream layout -
            # only the resize is conditional on `detect_resolution`; confirm.
            if "detect_resolution" in kwargs:
                image = self._to_resized_array(image, kwargs.pop("detect_resolution"))
            return PIL.Image.fromarray(self.model(image, **kwargs))

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = self._to_resized_array(image, detect_resolution)
            image = self.model(image, **kwargs)
            # The model output goes through HWC3 and a resize to the requested
            # output resolution before conversion to PIL.
            image = resize_image(HWC3(image), resolution=image_resolution)
            return PIL.Image.fromarray(image)

        return self.model(image, **kwargs)