# ZenCtrl / flux / condition.py
# Hosting-page residue: salso's picture · Upload 7 files · 5432315 verified
# Adapted from OminiControl and modified to accept an extra condition.
# While ZenCtrl pursued a similar idea, it diverged structurally.
# We appreciate the clarity of OminiControl's implementation and decided to align with it.
import torch
from typing import Optional, Union, List, Tuple
from diffusers.pipelines import FluxPipeline
from PIL import Image, ImageFilter
import numpy as np
import cv2
# from pipeline_tools import encode_images
from .pipeline_tools import encode_images
# Maps each supported condition type name to the integer id that
# Condition.type_id looks up and Condition.encode broadcasts per token.
condition_dict = dict(
    subject=1,
    sr=2,
    cot=3,
)
class Condition(object):
    """A typed conditioning input that can be encoded with a Flux pipeline.

    Attributes set by ``__init__``:
        condition_type: Name of the condition type; ``type_id`` looks it up
            in ``condition_dict``.
        condition: The condition payload, either derived from ``raw_img``
            via ``get_condition`` or taken verbatim from ``condition``.
        position_delta: Optional two-element offset that ``encode`` adds to
            columns 1 and 2 of the ids returned by ``encode_images``.
    """

    def __init__(
        self,
        condition_type: str,
        raw_img: Optional[Union[Image.Image, torch.Tensor]] = None,
        condition: Optional[Union[Image.Image, torch.Tensor]] = None,
        position_delta=None,
    ) -> None:
        """Build a condition from a raw image or from a ready-made condition.

        Args:
            condition_type: Name of the condition type.
            raw_img: Source image; when given it takes precedence and is
                passed through ``get_condition``.
            condition: Pre-computed condition, used only when ``raw_img``
                is ``None``.
            position_delta: Optional ``[d1, d2]`` offset applied to the
                position ids in ``encode``.

        Raises:
            ValueError: If both ``raw_img`` and ``condition`` are ``None``.
            NotImplementedError: If ``raw_img`` is given and
                ``condition_type`` is not supported by ``get_condition``.
        """
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would silently leave ``self.condition`` as None.
        if raw_img is None and condition is None:
            raise ValueError("Either raw_img or condition must be provided")

        self.condition_type = condition_type
        if raw_img is not None:
            self.condition = self.get_condition(condition_type, raw_img)
        else:
            self.condition = condition
        self.position_delta = position_delta

    def get_condition(
        self, condition_type: str, raw_img: Union[Image.Image, torch.Tensor]
    ) -> Union[Image.Image, torch.Tensor]:
        """Return the condition image derived from ``raw_img``.

        ``"subject"`` and ``"sr"`` return ``raw_img`` unchanged; ``"cot"``
        returns ``raw_img.convert("RGB")``.

        Raises:
            NotImplementedError: If ``condition_type`` is none of the above.
        """
        if condition_type in ("subject", "sr"):
            return raw_img
        if condition_type == "cot":
            # NOTE(review): relies on ``raw_img`` exposing ``convert`` (a PIL
            # image); a tensor input would fail here.
            return raw_img.convert("RGB")
        # This method runs inside __init__ before ``self.condition`` exists, so
        # falling back to that attribute would raise an opaque AttributeError.
        raise NotImplementedError(
            f"Condition type {condition_type} not implemented"
        )

    @property
    def type_id(self) -> int:
        """Return the integer id registered for this condition type.

        Raises:
            KeyError: If ``condition_type`` is not a key of ``condition_dict``.
        """
        return condition_dict[self.condition_type]

    def encode(
        self, pipe: FluxPipeline, empty: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the condition into tokens, ids and a per-row type id tensor.

        Args:
            pipe: Pipeline handed to ``encode_images``.
            empty: When true, encode an all-black RGB image of the same size
                as the condition instead of the condition itself.

        Returns:
            ``(tokens, ids, type_id)`` where ``tokens`` and ``ids`` come from
            ``encode_images`` (``ids`` shifted by ``position_delta`` when one
            is set) and ``type_id`` is a tensor shaped like ``ids[:, :1]``
            filled with ``self.type_id``.

        Raises:
            NotImplementedError: If ``condition_type`` is not supported.
        """
        if self.condition_type not in ("subject", "sr", "cot"):
            raise NotImplementedError(
                f"Condition type {self.condition_type} not implemented"
            )

        if empty:
            # Image.new already produces an RGB image, so no extra conversion
            # is needed. NOTE(review): ``.size`` is used as a (w, h) tuple,
            # which presumes a PIL image condition — confirm for tensors.
            source = Image.new("RGB", self.condition.size, (0, 0, 0))
        else:
            source = self.condition
        tokens, ids = encode_images(pipe, source)

        if self.position_delta is None and self.condition_type == "subject":
            # Default offset for subject conditions, stored on the instance so
            # later calls reuse it. Unary minus binds tighter than ``//``, so
            # this floors toward negative infinity.
            # NOTE(review): for widths not divisible by 16 this differs from
            # -(width // 16); 16 is presumably pixels per token — confirm.
            self.position_delta = [0, -self.condition.size[0] // 16]

        if self.position_delta is not None:
            # Shift the ids in place; columns 1 and 2 are presumably the two
            # spatial coordinates — TODO confirm against encode_images.
            ids[:, 1] += self.position_delta[0]
            ids[:, 2] += self.position_delta[1]

        type_id = torch.ones_like(ids[:, :1]) * self.type_id
        return tokens, ids, type_id