import os
from abc import ABC, abstractmethod
from functools import partial, reduce
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from PIL import Image
from transformers import SiglipVisionConfig, SiglipVisionModel
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    PILImageResampling,
    to_numpy_array,
)


def rank0_print(*args):
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)


class BaseVisionTower(nn.Module, ABC):
    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower_name
        self.delay_load = delay_load

    @abstractmethod
    def load_model(self, device_map=None):
        raise NotImplementedError("Subclasses must implement load_model")

    @abstractmethod
    def _forward(self, images):
        raise NotImplementedError("Subclasses must implement _forward")

    def forward(self, images):
        if isinstance(images, list):
            image_features = [self._forward(image.unsqueeze(0)) for image in images]
        else:
            image_features = self._forward(images)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Dynamically infer the dtype from the first parameter if the wrapped
        # model does not expose one explicitly.
        if hasattr(self.vision_tower, "dtype"):
            return self.vision_tower.dtype
        params = list(self.vision_tower.parameters())
        # Default to torch.float32 if the module has no parameters.
        return params[0].dtype if len(params) > 0 else torch.float32

    @property
    def device(self):
        # Dynamically infer the device from the first parameter if the wrapped
        # model does not expose one explicitly.
        if hasattr(self.vision_tower, "device"):
            return self.vision_tower.device
        params = list(self.vision_tower.parameters())
        # Default to CPU if the module has no parameters.
        return params[0].device if len(params) > 0 else torch.device("cpu")

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        try:
            return self.config.hidden_size
        except AttributeError:
            return self._hidden_size


class SigLipImageProcessor:
    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(384, 384),
        crop_size: Dict[str, int] = None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
    ):
        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size

    def preprocess(self, images, return_tensors):
        if isinstance(images, Image.Image):
            images = [images]
        else:
            # To adapt video data: convert each frame to a numpy array first.
            images = [to_numpy_array(image) for image in images]
        assert isinstance(images, list)

        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]

        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
        data = {"pixel_values": images}

        return BatchFeature(data=data, tensor_type=return_tensors)
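# A minimal usage sketch for the processor above, assuming a local RGB image;
# the file name "photo.jpg" is hypothetical. With the defaults, the resulting
# tensor is (1, 3, 384, 384):
#
#     processor = SigLipImageProcessor()
#     batch = processor.preprocess(Image.open("photo.jpg"), return_tensors="pt")
#     batch["pixel_values"].shape  # torch.Size([1, 3, 384, 384])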
class SigLipVisionTower(BaseVisionTower):
    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        super(SigLipVisionTower, self).__init__(vision_tower_name, vision_tower_cfg, delay_load)

        # e.g. "google/siglip-so400m-patch14-384", operating at a fixed 384px resolution
        res = 384
        self._image_size = res if res is not None else 512
        self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower_name}")
            self.load_model()
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # TODO: a better detector is needed.
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # Load only the config so `self.config` works before the weights do.
            self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        self.vision_model = "siglip"
        rank0_print(f"Loading SigLip weights from: {self.vision_tower_name}")
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.output_tokens = True

        self._hidden_size = self.vision_tower.config.hidden_size
        self.image_processor = SigLipImageProcessor()

        # Drop the last encoder layer and the pooling head: features are taken
        # from the patch tokens of the remaining final layer.
        del self.vision_tower.vision_model.encoder.layers[-1:]
        self.vision_tower.vision_model.head = nn.Identity()
        self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
        self.is_loaded = True

    def _forward(self, images):
        # Only track gradients when the tower itself is being fine-tuned.
        with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
            image_features = self.vision_tower.forward(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            ).hidden_states[-1]
            return image_features

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size  # 384 // 14 = 27

    @property
    def num_patches(self):
        return self.num_patches_per_side**2  # 27 ** 2 = 729

    @property
    def image_size(self):
        return self._image_size


def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if vision_tower is None:
        raise ValueError("No vision tower specified in `vision_tower_cfg`.")

    # Accept either a local checkpoint path or a SigLip model identifier.
    if os.path.exists(vision_tower) or "siglip" in vision_tower.lower():
        return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)

    raise ValueError(f"Unknown vision tower: {vision_tower}")
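
# A minimal smoke test, assuming network access to the Hugging Face Hub. The
# checkpoint id "google/siglip-so400m-patch14-384" is an assumption based on
# the 384px / patch-14 settings above, and `SimpleNamespace` stands in for
# the real training config object.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(mm_vision_tower="google/siglip-so400m-patch14-384")  # assumed checkpoint
    tower = build_vision_tower(cfg)

    image = Image.new("RGB", (512, 512))  # dummy input image
    batch = tower.image_processor.preprocess(image, return_tensors="pt")
    with torch.no_grad():
        features = tower(batch["pixel_values"])
    print(features.shape)  # expected: torch.Size([1, 729, 1152]) for so400m-patch14-384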