|
import base64
import logging
import math
from io import BytesIO
from typing import Any, Dict, List, Optional

import requests
import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoModelForVision2Seq, AutoProcessor


class MultiModalTransformer(BaseTransformer):
|
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        min_image_tokens: int = 256,
        max_image_tokens: int = 1280,
        max_length: int = 1800,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Each image token covers a 28x28 pixel area, so the image-token
        # budget translates directly into a pixel budget for the processor.
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **tokenizer_args
        )
        self.processor.tokenizer.padding_side = 'right'
        self.sep = ' '
        self.max_length = max_length
        self.normalize = True
|
    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        # Load the backbone as a vision-to-sequence model; fp16 halves memory.
        model_args.pop("trust_remote_code", None)
        self.auto_model = AutoModelForVision2Seq.from_pretrained(
            model_name_or_path, torch_dtype=torch.float16, **model_args
        )
|
    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        if features.get("inputs_embeds", None) is None:
            features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
        if features.get("pixel_values", None) is not None:
            # Run the vision tower and splice its embeddings into the text
            # sequence at the positions of the image placeholder tokens.
            features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
            image_embeds = self.auto_model.visual(
                features["pixel_values"], grid_thw=features["image_grid_thw"]
            )
            image_mask = features["input_ids"] == self.auto_model.config.image_token_id
            features["inputs_embeds"][image_mask] = image_embeds

        # Forward only the keys the language model expects; input_ids is
        # deliberately excluded because inputs_embeds already carries it.
        inputs = {k: v for k, v in features.items() if k in ('position_ids', 'attention_mask', 'inputs_embeds')}
        outputs = self.auto_model.model(
            **inputs,
            return_dict=True,
            output_hidden_states=True,
        )
        features.update({"token_embeddings": outputs.last_hidden_state})
        return features
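
    # Note: pooling token_embeddings down to one vector per input is left to
    # the surrounding SentenceTransformer modules; GME-style checkpoints are
    # typically configured for last-token pooling, so the appended
    # <|endoftext|> token acts as the embedding token.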
|
    def tokenize(self, texts: List[Dict[str, Any]] | List[str]) -> Dict[str, torch.Tensor]:
        default_instruction = 'You are a helpful assistant.'

        all_texts, all_images = list(), list()
        for item in texts:
            if isinstance(item, str):
                txt, img, inst = item, None, default_instruction
            elif isinstance(item, dict):
                txt = item.get('text', None)
                img = item.get('image', None)
                inst = item.get('prompt', default_instruction)
            else:
                raise RuntimeError(f'Input format not supported! {item=}')

            input_str = ''
            if img is None:
                # Batches are assumed to be homogeneous: either every item
                # carries an image or none does.
                all_images = None
            else:
                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
                img = fetch_image(img)
                all_images.append(img)
            if txt is not None:
                input_str += txt
            msg = f'<|im_start|>system\n{inst}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
            all_texts.append(msg)

        inputs = self.processor(
            text=all_texts,
            images=all_images,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return inputs
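
# Example: tokenize([{'text': 'a cat', 'image': 'cat.png'}]) (file name
# illustrative) renders the following chat-formatted string before the
# processor is called:
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   <|vision_start|><|image_pad|><|vision_end|>a cat<|im_end|>
#   <|im_start|>assistant
#   <|endoftext|>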
|
# Pixel budgets for Qwen2-VL-style image preprocessing. Each image token
# covers a 28x28 pixel area (14x14 vision patches merged 2x2), so these
# bounds keep an image between 4 and 16384 visual tokens.
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
|
|
def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor
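
# For example, round_by_factor(100, 28) == 112, ceil_by_factor(100, 28) == 112,
# and floor_by_factor(100, 28) == 84.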
|
|
def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too many pixels: scale both sides down by the same factor, rounding
        # down so the result stays within the budget.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: scale both sides up, rounding up to clear the floor.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
        logging.warning(
            f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
        )
        if h_bar > w_bar:
            h_bar = w_bar * MAX_RATIO
        else:
            w_bar = h_bar * MAX_RATIO
    return h_bar, w_bar
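
# Example: smart_resize(1000, 800) snaps both sides to multiples of 28,
# giving (1008, 812); 1008 * 812 pixels already lies inside
# [MIN_PIXELS, MAX_PIXELS], so no further rescaling is applied.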
|
|
def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(requests.get(image, stream=True).raw)
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        # Data URI: decode the base64 payload after the "base64," marker.
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        # Fall back to treating the string as a local file path.
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")

    width, height = image.size
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )
    image = image.resize((resized_width, resized_height))

    return image
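

# Minimal smoke-test sketch. Assumptions: the checkpoint path below is a
# placeholder for a Qwen2-VL-style embedding checkpoint compatible with this
# module, and in normal use sentence-transformers instantiates this class via
# the repository's modules.json rather than by direct construction.
if __name__ == "__main__":
    module = MultiModalTransformer("path/to/qwen2-vl-checkpoint")
    batch = module.tokenize([
        "A photo of a cat sitting on a sofa.",
        "A photo of a dog in the park.",
    ])
    with torch.no_grad():
        features = module.forward(batch)
    # One vector per token; pooling to a single embedding per input is left
    # to the downstream sentence-transformers modules.
    print(features["token_embeddings"].shape)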