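"""Dataset base class and batch collator for Ovis multimodal training."""
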
import logging
import os
from typing import Dict, Sequence, Union, List

import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import TrainingArguments
from ovis.util.constants import IGNORE_ID


class MultimodalDataset(Dataset):
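    """Abstract base for datasets pairing images with tokenized text.

    Subclasses implement `load()` (parse `meta_file` into a sample list) and
    `__getitem__` (build the per-sample tensor dict).
    """
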
    def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
        self.name = name
        self.meta_file = info['meta_file']
        self.image_dir = info['image_dir']
        self.caption_template = info.get('caption_template', None)
        self.text_tokenizer = model.get_text_tokenizer()
        self.visual_tokenizer = model.get_visual_tokenizer()
        self.image_height, self.image_width = self.visual_tokenizer.get_image_size()
        self.model = model
        self.text_max_length = training_args.text_max_length
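        # `max_partitions` is given as a '|'-separated string of integers.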
        self.max_partitions = [int(m.strip()) for m in training_args.max_partitions.split('|')]
        self.samples = self.load()

    def load(self):
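        """Parse `self.meta_file` into the list of samples; implemented by subclasses."""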
        raise NotImplementedError

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
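        """Build the tensor dict (`pixel_values`, `input_ids`, `labels`) for
        sample `i`; implemented by subclasses."""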
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)

    def read_image(self, path):
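        """Open the image at `path` (relative to `self.image_dir`) as RGB;
        returns (image, None) on success or (None, exception) on failure."""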
        try:
            full_path = os.path.join(self.image_dir, path)
            image = Image.open(full_path).convert('RGB')
            return image, None
        except Exception as e:
            return None, e


class DataCollatorForMultimodalDataset:
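    """Collates per-sample dicts into a batch: pads `input_ids` and `labels`,
    derives `attention_mask`, and passes `pixel_values` through unchanged."""
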
    def __init__(self, text_tokenizer: PreTrainedTokenizer):
        self.text_tokenizer = text_tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
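        # Gather each field from the per-sample dicts into parallel lists.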
        pixel_values, input_ids, labels = tuple([instance[key] for instance in instances]
                                                for key in ("pixel_values", "input_ids", "labels"))
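        # Right-pad input_ids to the longest sequence in the batch; padded
        # positions (== pad_token_id) are masked out via attention_mask below.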
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.text_tokenizer.pad_token_id)
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
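        # Pad labels with IGNORE_ID so padded positions are excluded from the loss.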
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_ID)
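        # A batch in which every label is IGNORE_ID contributes no supervised
        # tokens; warn so such batches are visible in the logs.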
        num_valid_label = torch.not_equal(labels, IGNORE_ID).sum().item()
        if num_valid_label == 0:
            logging.warning(
                '[DataCollatorForMultimodalDataset] All labels in a batch are ignored, '
                'which may lead to training instability\n'
                f'{input_ids=}\n{attention_mask=}\n{labels=}')
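        # pixel_values is returned as a list of per-sample tensors (shapes may
        # differ across samples), so it is neither padded nor stacked here.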
        return dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pixel_values=pixel_values
        )
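

# Usage sketch (hypothetical wiring; `MyCaptionDataset`, `info`, and the
# argument values below are illustrative and not defined in this module):
#
#   dataset = MyCaptionDataset(name='caption', info=info, model=model,
#                              training_args=training_args)
#   collator = DataCollatorForMultimodalDataset(model.get_text_tokenizer())
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8,
#                                        collate_fn=collator)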