import logging
import os
from typing import Dict, Sequence, Union, List
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import TrainingArguments
from ovis.util.constants import IGNORE_ID
class MultimodalDataset(Dataset):
    """Abstract base dataset pairing images with text for Ovis training.

    Subclasses must implement `load` (returning the sample list) and
    `__getitem__` (returning a dict of tensors for one sample).
    """

    def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
        self.name = name
        self.meta_file = info['meta_file']
        self.image_dir = info['image_dir']
        # Optional per-dataset caption template; None when absent.
        self.caption_template = info.get('caption_template')
        self.model = model
        self.text_tokenizer = model.get_text_tokenizer()
        self.visual_tokenizer = model.get_visual_tokenizer()
        self.image_height, self.image_width = self.visual_tokenizer.get_image_size()
        self.text_max_length = training_args.text_max_length
        # '|'-separated string like "9|1" -> [9, 1]
        self.max_partitions = [int(part.strip()) for part in training_args.max_partitions.split('|')]
        self.samples = self.load()

    def load(self):
        """Read `meta_file` and return the list of raw samples (subclass hook)."""
        raise NotImplementedError

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        """Return the i-th sample as model-ready tensors (subclass hook)."""
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)

    def read_image(self, path):
        """Best-effort image load from `image_dir`.

        Returns (image, None) on success, (None, exception) on any failure —
        callers decide how to handle unreadable images without crashing training.
        """
        try:
            loaded = Image.open(os.path.join(self.image_dir, path))
            return loaded.convert('RGB'), None
        except Exception as error:
            return None, error
class DataCollatorForMultimodalDataset:
    """Batch collator: pads text tensors and passes pixel values through as a list."""

    def __init__(self, text_tokenizer: PreTrainedTokenizer):
        self.text_tokenizer = text_tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        pad_id = self.text_tokenizer.pad_token_id
        # Pixel values stay a list (per-sample tensors may differ in shape).
        pixel_values = [instance["pixel_values"] for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [instance["input_ids"] for instance in instances],
            batch_first=True,
            padding_value=pad_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            [instance["labels"] for instance in instances],
            batch_first=True,
            padding_value=IGNORE_ID,
        )
        # Positions equal to the pad id are masked out of attention.
        attention_mask = torch.ne(input_ids, pad_id)
        # A batch whose labels are all IGNORE_ID contributes no loss signal.
        if torch.not_equal(labels, IGNORE_ID).sum().item() == 0:
            logging.warning(
                f'[DataCollatorForMultimodalDataset] All labels in a batch are ignored, which may lead to training instability\n{input_ids=}\n{attention_mask=}\n{labels=}')
        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pixel_values=pixel_values
        )