from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image

from ....utils import dist_utils


class ImageConditionedMixin:
    """
    Mixin for image-conditioned models.

    Args:
        image_cond_model: Name of the DINOv2 image-conditioning model to load
            from torch.hub ('facebookresearch/dinov2').
    """
    def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
        super().__init__(*args, **kwargs)
        self.image_cond_model_name = image_cond_model
        self.image_cond_model = None    # loaded lazily on first use

    @staticmethod
    def prepare_for_training(image_cond_model: str, **kwargs):
        """
        Prepare for training: download the image-conditioning model into the
        torch.hub cache ahead of time.
        """
        if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
            super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)

        torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)

    def _init_image_cond_model(self):
        """
        Initialize the image-conditioning model.
        """
        with dist_utils.local_master_first():
            dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
        dinov2_model.eval().cuda()
        transform = transforms.Compose([
            # ImageNet normalization statistics expected by DINOv2
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.image_cond_model = {
            'model': dinov2_model,
            'transform': transform,
        }

    @torch.no_grad()
    def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Encode the image into DINOv2 token features.

        Args:
            image: A batched image tensor of shape (B, C, H, W) with values in
                [0, 1], or a list of PIL images.

        Returns:
            A tensor of layer-normalized tokens of shape (B, N, D).
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            image = [i.resize((518, 518), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).cuda()
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        if self.image_cond_model is None:
            self._init_image_cond_model()
        image = self.image_cond_model['transform'](image).cuda()
        features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens

    def get_cond(self, cond, **kwargs):
        """
        Get the conditioning data. A zero embedding is attached as ``neg_cond``
        to serve as the unconditional branch.
        """
        cond = self.encode_image(cond)
        kwargs['neg_cond'] = torch.zeros_like(cond)
        cond = super().get_cond(cond, **kwargs)
        return cond

    def get_inference_cond(self, cond, **kwargs):
        """
        Get the conditioning data for inference.
        """
        cond = self.encode_image(cond)
        kwargs['neg_cond'] = torch.zeros_like(cond)
        cond = super().get_inference_cond(cond, **kwargs)
        return cond

    def vis_cond(self, cond, **kwargs):
        """
        Visualize the conditioning data.
        """
        return {'image': {'value': cond, 'type': 'image'}}
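
# Usage sketch (illustrative only, not from this codebase): the mixin is meant
# to be combined, via multiple inheritance, with a base class that provides
# get_cond / get_inference_cond / prepare_for_training. The names
# `BaseFlowTrainer`, `model`, and 'photo.png' below are hypothetical
# placeholders.
#
#   class ImageConditionedFlowTrainer(ImageConditionedMixin, BaseFlowTrainer):
#       pass
#
#   trainer = ImageConditionedFlowTrainer(model, image_cond_model='dinov2_vitl14_reg')
#   tokens = trainer.encode_image([Image.open('photo.png')])       # (1, N, D) DINOv2 tokens
#   cond = trainer.get_inference_cond([Image.open('photo.png')])   # includes zero neg_cond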