|
import os |
|
from typing import Optional, Tuple, Union, Dict |
|
from PIL import Image |
|
from functools import partial, reduce |
|
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel |
|
import torch.distributed as dist |
|
from abc import ABC, abstractmethod |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers.image_processing_utils import BatchFeature, get_size_dict |
|
from transformers.image_transforms import ( |
|
convert_to_rgb, |
|
normalize, |
|
rescale, |
|
resize, |
|
to_channel_dimension_format, |
|
) |
|
from transformers.image_utils import ( |
|
ChannelDimension, |
|
PILImageResampling, |
|
to_numpy_array, |
|
) |
|
|
|
def rank0_print(*args):
    """Print `args`, but only on rank 0 when torch.distributed is initialized.

    Outside a distributed run every process prints; inside one, only rank 0
    prints, prefixed with its rank.
    """
    if not dist.is_initialized():
        print(*args)
    elif dist.get_rank() == 0:
        print(f"Rank {dist.get_rank()}: ", *args)
|
|
|
|
|
class BaseVisionTower(nn.Module):
    """Abstract base for vision-tower wrappers.

    Subclasses implement `load_model` (must set `self.vision_tower`) and
    `_forward`; this base adds per-image batching over lists and
    dtype/device/config introspection on the wrapped model.

    NOTE(review): `@abstractmethod` is inert here because the class does not
    use `ABCMeta`; instantiation stays possible (original behavior preserved)
    and the methods raise `NotImplementedError` explicitly instead.
    """

    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        super().__init__()

        # Flipped to True by subclasses once model weights are materialized.
        self.is_loaded = False

        self.vision_tower_name = vision_tower_name
        self.delay_load = delay_load

    @abstractmethod
    def load_model(self, device_map=None):
        """Load the underlying vision model; must set `self.vision_tower`."""
        raise NotImplementedError("Subclasses must implement load_model")

    @abstractmethod
    def _forward(self, images):
        """Encode a batched image tensor into features."""
        raise NotImplementedError("Subclasses must implement forward")

    def forward(self, images):
        """Encode `images`.

        A list is encoded per-image (each entry unsqueezed to a batch of 1,
        returning a list of features); a tensor is encoded as one batch.
        """
        if isinstance(images, list):  # was `type(images) is list`
            return [self._forward(image.unsqueeze(0)) for image in images]
        return self._forward(images)

    @property
    def dummy_feature(self):
        """A zero feature of shape (1, hidden_size) on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype of the wrapped tower: its own `dtype` attribute if present,
        else the first parameter's dtype, else float32 for a parameterless tower."""
        if hasattr(self.vision_tower, "dtype"):
            return self.vision_tower.dtype
        params = list(self.vision_tower.parameters())
        return params[0].dtype if params else torch.float32

    @property
    def device(self):
        """device of the wrapped tower: its own `device` attribute if present,
        else the first parameter's device, else CPU for a parameterless tower."""
        if hasattr(self.vision_tower, "device"):
            return self.vision_tower.device
        params = list(self.vision_tower.parameters())
        return params[0].device if params else torch.device("cpu")

    @property
    def config(self):
        """The loaded model's config, or the config-only stub before loading."""
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        try:
            return self.config.hidden_size
        except AttributeError:
            # Fall back to the size cached by the subclass at load time.
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; only a missing attribute is expected.
            return self._hidden_size
|
|
|
class SigLipImageProcessor:
    """Minimal SigLip image processor.

    Pipeline per image: convert to RGB -> numpy -> resize -> rescale ->
    normalize -> channel-dimension format, then wrap as a `BatchFeature`.
    """

    def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
        if crop_size is None:
            crop_size = {"height": 384, "width": 384}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size

    def preprocess(self, images, return_tensors):
        """Apply the transform pipeline to one image or an iterable of images.

        Returns a `BatchFeature` with key "pixel_values", converted to
        `return_tensors` (e.g. "pt").
        """
        if isinstance(images, Image.Image):
            images = [images]
        else:
            # An iterable of images: materialize each as a numpy array.
            images = [to_numpy_array(image) for image in images]
        assert isinstance(images, list)

        pipeline = (
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        )
        # Apply each stage to every image before moving to the next stage
        # (same order as the original reduce-over-transforms).
        for step in pipeline:
            images = [step(image) for image in images]

        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
|
|
class SigLipVisionTower(BaseVisionTower):
    """Vision tower backed by a HuggingFace `SiglipVisionModel`.

    Features come from the penultimate layer: at load time the final encoder
    block is deleted and the pooling head replaced with `nn.Identity`, so
    `_forward` returns the hidden states of the last remaining layer.
    """

    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        super(SigLipVisionTower, self).__init__(vision_tower_name, vision_tower_cfg, delay_load)

        # Hard-coded geometry for the 384px SigLip checkpoints; `interp` is
        # never used afterwards. TODO(review): confirm these should not be
        # derived from `vision_tower_cfg` instead.
        self.vision_tower_name, res, interp = vision_tower_name, 384, 576
        self._image_size = res if res is not None else 512
        self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower_name}")
            self.load_model()
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # Delay was requested, but the tower will be trained, so its
            # weights must exist before a checkpoint containing them loads.
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # BUGFIX: the original `self.cfg_only = self.config` read `cfg_only`
            # (through the base-class `config` property, since `is_loaded` is
            # False) before it was ever assigned, raising AttributeError on every
            # delayed load. Load just the config instead.
            self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Instantiate the SigLip model, strip its last layer and head, and
        set its requires_grad state from `unfreeze_mm_vision_tower`."""
        self.vision_model = "siglip"

        # Was a stray bare `print`; use rank0_print for consistency with the
        # rest of this file.
        rank0_print(self.vision_tower_name)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)

        self.vision_tower.output_tokens = True
        self._hidden_size = self.vision_tower.config.hidden_size

        self.image_processor = SigLipImageProcessor()

        # Use the penultimate layer as the feature layer: drop the final
        # encoder block and disable the attention-pooling head.
        del self.vision_tower.vision_model.encoder.layers[-1:]
        self.vision_tower.vision_model.head = nn.Identity()

        self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)

        self.is_loaded = True

    def _forward(self, images):
        """Encode `images`; gradients flow only when the tower is unfrozen."""
        with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
            image_features = self.vision_tower.forward(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            ).hidden_states[-1]
            return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # dtype of the first parameter; implicitly None for a parameterless tower.
        for p in self.vision_tower.parameters():
            return p.dtype

    @property
    def device(self):
        # device of the first parameter; implicitly None for a parameterless tower.
        for p in self.vision_tower.parameters():
            return p.device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        # NOTE(review): 336 // 14 == 24 looks like a CLIP-ViT-L/336 leftover;
        # for a 384px SigLip tower this likely disagrees with the actual token
        # count. Left unchanged to avoid silently altering downstream shapes.
        # TODO confirm against the projector/caller.
        return (336 // 14) ** 2

    @property
    def num_patches_per_side(self):
        # See the NOTE on `num_patches` -- same hard-coded 336/14 assumption.
        return 336 // 14

    @property
    def image_size(self):
        # Fixed input resolution of the supported checkpoints.
        return 384
|
|
|
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build the vision tower named in `vision_tower_cfg`.

    The tower name is read from `mm_vision_tower`, falling back to
    `vision_tower`. Only SigLip towers are constructed here.

    Raises:
        ValueError: if the config names no vision tower. (Previously this
        crashed with a TypeError inside `os.path.exists(None)`.)
    """
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if vision_tower is None:
        raise ValueError(f"Unknown vision tower: {vision_tower}")

    # BUGFIX: everything after the original unconditional `return` was dead
    # code -- an unused `os.path.exists` probe, an unused `s2` flag, and an
    # unreachable duplicate construction guarded by a "siglip" substring test.
    # The function always built a SigLipVisionTower; now it does so explicitly.
    return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
|