matgen / preprocessor.py
jingyangcarl's picture
update
9d36776
import gc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
HEDdetector,
LineartAnimeDetector,
LineartDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
)
from controlnet_aux.util import HWC3
from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor
class Preprocessor:
    """Hold one named annotator model at a time and apply it to images.

    ``load`` selects the annotator by name; calling the instance runs the
    currently loaded annotator.  Loading a different name replaces the
    previously held model.
    """

    # Repository id handed to every ``from_pretrained`` call in ``load``.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self) -> None:
        """Start with no model loaded; ``name`` is the empty string."""
        self.model: Callable = None  # type: ignore
        self.name = ""

    def load(self, name: str) -> None:
        """Load the annotator called ``name`` unless it is already loaded.

        Args:
            name: One of ``"HED"``, ``"Midas"``, ``"MLSD"``, ``"Openpose"``,
                ``"PidiNet"``, ``"NormalBae"``, ``"Lineart"``,
                ``"LineartAnime"``, ``"Canny"``, ``"ContentShuffle"``,
                ``"DPT"``, ``"UPerNet"`` or ``"texnet"``.

        Raises:
            ValueError: If ``name`` is not one of the names listed above.
        """
        if name == self.name:
            return

        model_id = self.MODEL_ID
        # Every entry is a zero-argument factory so that nothing is looked up
        # or constructed until the requested name is known to be valid.
        factories = {
            "HED": lambda: HEDdetector.from_pretrained(model_id),
            "Midas": lambda: MidasDetector.from_pretrained(model_id),
            "MLSD": lambda: MLSDdetector.from_pretrained(model_id),
            "Openpose": lambda: OpenposeDetector.from_pretrained(model_id),
            "PidiNet": lambda: PidiNetDetector.from_pretrained(model_id),
            "NormalBae": lambda: NormalBaeDetector.from_pretrained(model_id),
            "Lineart": lambda: LineartDetector.from_pretrained(model_id),
            "LineartAnime": lambda: LineartAnimeDetector.from_pretrained(model_id),
            "Canny": lambda: CannyDetector(),
            "ContentShuffle": lambda: ContentShuffleDetector(),
            "DPT": lambda: DepthEstimator(),
            "UPerNet": lambda: ImageSegmentor(),
            "texnet": lambda: TexnetPreprocessor(),
        }
        factory = factories.get(name)
        if factory is None:
            raise ValueError(f"Unknown preprocessor name: {name!r}")

        # Rebinding self.model drops the reference to the previous model.
        self.model = factory()
        # Collect garbage first so that memory held only by unreachable
        # objects is released before the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:  # noqa: ANN003
        """Run the currently loaded annotator on ``image``.

        Args:
            image: Input image.
            **kwargs: Forwarded to the loaded model.  ``detect_resolution``
                (Canny and Midas) and ``image_resolution`` (Midas only) are
                consumed here and not forwarded; for Midas both default to
                512.

        Returns:
            For ``"Canny"`` and ``"Midas"``, the model output wrapped with
            ``PIL.Image.fromarray``.  For every other annotator, whatever the
            model itself returns.
        """
        if self.name == "Canny":
            # Always hand the detector an ndarray (as the Midas branch does);
            # resize only when the caller asked for a detect resolution.
            array = HWC3(np.array(image))
            if "detect_resolution" in kwargs:
                array = resize_image(array, resolution=kwargs.pop("detect_resolution"))
            return PIL.Image.fromarray(self.model(array, **kwargs))

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            array = HWC3(np.array(image))
            array = resize_image(array, resolution=detect_resolution)
            # The model output goes through HWC3 again before the final
            # resize to the requested output resolution.
            result = HWC3(self.model(array, **kwargs))
            result = resize_image(result, resolution=image_resolution)
            return PIL.Image.fromarray(result)

        return self.model(image, **kwargs)
# https://github.com/huggingface/controlnet_aux/blob/master/src/controlnet_aux/canny/__init__.py
class TexnetPreprocessor:
    """Pass-through preprocessor: normalise and resize the input, detect nothing.

    The call signature follows a Canny-style detector, but no edge detection
    is performed; the resized input itself is returned.
    """

    def __call__(
        self,
        input_image=None,
        low_threshold=100,
        high_threshold=200,
        image_resolution=512,
        output_type=None,
        **kwargs,
    ):
        """Return ``input_image`` after ``HWC3`` and ``resize_image``.

        Args:
            input_image: An ``np.ndarray`` or anything ``np.array`` can
                convert (e.g. a PIL image).
            low_threshold: Unused; kept so the signature stays compatible
                with existing callers.
            high_threshold: Unused; kept so the signature stays compatible
                with existing callers.
            image_resolution: Resolution passed to ``resize_image``.
            output_type: ``"pil"`` to get a PIL image back; any other value
                returns an ``np.ndarray``.  When omitted it defaults to
                ``"np"`` for ndarray input and ``"pil"`` otherwise.
            **kwargs: Only the deprecated ``img`` alias for ``input_image``
                is recognised.

        Returns:
            The resized image as a PIL image or an ``np.ndarray``, depending
            on ``output_type``.

        Raises:
            ValueError: If no input image was supplied.
        """
        if "img" in kwargs:
            # Imported here because the module does not import ``warnings``
            # at top level.
            import warnings

            warnings.warn(
                "img is deprecated, please use `input_image=...` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            input_image = kwargs.pop("img")

        if input_image is None:
            raise ValueError("input_image must be defined.")

        if isinstance(input_image, np.ndarray):
            output_type = output_type or "np"
        else:
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, image_resolution)

        output_image = input_image.copy()
        if output_type == "pil":
            output_image = PIL.Image.fromarray(output_image)
        return output_image