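# Vision tower builder for Video-XL-2: wraps a pretrained SigLIP vision encoder
# (e.g. "google/siglip-so400m-patch14-384") and its image preprocessing behind
# the interface expected by the multimodal model.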
from abc import ABC, abstractmethod
from functools import partial, reduce
from typing import Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from PIL import Image
from transformers import SiglipVisionConfig, SiglipVisionModel
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    normalize,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    PILImageResampling,
    to_numpy_array,
)

def rank0_print(*args):
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(f"Rank {dist.get_rank()}: ", *args)
    else:
        print(*args)

class BaseVisionTower(nn.Module, ABC):
    """Abstract base class for vision towers; subclasses implement loading and the forward pass."""

    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower_name
        self.delay_load = delay_load

    @abstractmethod
    def load_model(self, device_map=None):
        raise NotImplementedError("Subclasses must implement load_model")

    @abstractmethod
    def _forward(self, images):
        raise NotImplementedError("Subclasses must implement _forward")

    def forward(self, images):
        if isinstance(images, list):
            image_features = [self._forward(image.unsqueeze(0)) for image in images]
        else:
            image_features = self._forward(images)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Dynamically infer the dtype from the first parameter, if not explicitly specified.
        if hasattr(self.vision_tower, "dtype"):
            return self.vision_tower.dtype
        else:
            params = list(self.vision_tower.parameters())
            # Default to torch.float32 if the tower has no parameters.
            return params[0].dtype if len(params) > 0 else torch.float32

    @property
    def device(self):
        # Dynamically infer the device from the first parameter, if not explicitly specified.
        if hasattr(self.vision_tower, "device"):
            return self.vision_tower.device
        else:
            params = list(self.vision_tower.parameters())
            # Default to CPU if the tower has no parameters.
            return params[0].device if len(params) > 0 else torch.device("cpu")

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        try:
            return self.config.hidden_size
        except AttributeError:
            return self._hidden_size

class SigLipImageProcessor:
    def __init__(
        self,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        size=(384, 384),
        crop_size: Dict[str, int] = None,
        resample=PILImageResampling.BICUBIC,
        rescale_factor=1 / 255,
        data_format=ChannelDimension.FIRST,
    ):
        crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.image_mean = image_mean
        self.image_std = image_std
        self.size = size
        self.resample = resample
        self.rescale_factor = rescale_factor
        self.data_format = data_format
        self.crop_size = crop_size

    def preprocess(self, images, return_tensors):
        if isinstance(images, Image.Image):
            images = [images]
        else:
            # To adapt video data: convert each frame to a numpy array.
            images = [to_numpy_array(image) for image in images]
        assert isinstance(images, list)

        # Apply each transform to every image in turn: RGB conversion, resize to
        # `size`, rescale to [0, 1], then normalize with `image_mean`/`image_std`.
        transforms = [
            convert_to_rgb,
            to_numpy_array,
            partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
            partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
            partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
        ]
        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
        data = {"pixel_values": images}

        return BatchFeature(data=data, tensor_type=return_tensors)

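# Minimal usage sketch (illustrative; "frame.png" is a placeholder path):
#   processor = SigLipImageProcessor()
#   batch = processor.preprocess(Image.open("frame.png"), return_tensors="pt")
#   pixel_values = batch["pixel_values"]  # (1, 3, 384, 384), channel-first
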
class SigLipVisionTower(BaseVisionTower):
    def __init__(self, vision_tower_name, vision_tower_cfg, delay_load=False):
        super().__init__(vision_tower_name, vision_tower_cfg, delay_load)

        # Hard-coded input resolution, e.g. for "google/siglip-so400m-patch14-384".
        res = 384
        self._image_size = res
        self.unfreeze_mm_vision_tower = getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False)

        if not delay_load:
            rank0_print(f"Loading vision tower: {vision_tower_name}")
            self.load_model()
        elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
            # TODO: a better detector is needed.
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
            rank0_print("The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()
        else:
            # Load the config only, so `self.config` works before the weights are loaded.
            self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        self.vision_model = "siglip"
        rank0_print(f"Loading SigLIP weights from: {self.vision_tower_name}")
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        # Leftover flag from the open_clip API; harmless for SiglipVisionModel.
        self.vision_tower.output_tokens = True

        self._hidden_size = self.vision_tower.config.hidden_size
        self.image_processor = SigLipImageProcessor()

        # Drop the final encoder block and the pooling head: features are taken
        # from the last remaining hidden state, one token per patch.
        del self.vision_tower.vision_model.encoder.layers[-1:]
        self.vision_tower.vision_model.head = nn.Identity()

        self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
        self.is_loaded = True

    def _forward(self, images):
        # Enable gradients only when the tower itself is being fine-tuned.
        with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
            image_features = self.vision_tower.forward(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            ).hidden_states[-1]
            return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # dtype of the first parameter; default to float32 for a parameterless tower.
        for p in self.vision_tower.parameters():
            return p.dtype
        return torch.float32

    @property
    def device(self):
        for p in self.vision_tower.parameters():
            return p.device
        return torch.device("cpu")

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (336 // 14) ** 2

    @property
    def num_patches_per_side(self):
        # Hard-coded patch grid: 336 // 14 = 24 patches per side.
        return 336 // 14

    @property
    def image_size(self):
        # Hard-coded input resolution.
        return 384

def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
    if vision_tower is None:
        raise ValueError("vision_tower_cfg must define `mm_vision_tower` or `vision_tower`.")
    # Only the SigLIP tower is supported by this builder.
    return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
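
# Minimal usage sketch (illustrative; `cfg` is any object exposing an
# `mm_vision_tower` or `vision_tower` attribute, e.g. a model config pointing
# at "google/siglip-so400m-patch14-384"):
#   tower = build_vision_tower(cfg)
#   feats = tower(pixel_values)  # (batch, num_tokens, hidden_size)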