import torch
from PIL import Image
from torch.utils.data import Dataset

import models.config as cfg


class VQADataset(Dataset):  # Visual Question Answering Dataset
    def __init__(self, dataset, tokenizer, image_processor):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Handle the image field (usually a list; fall back to a single image)
        image_data = item['images']
        if isinstance(image_data, list) and len(image_data) > 0:
            image = image_data[0]
        else:
            image = image_data

        # Now process the image
        if isinstance(image, Image.Image):
            if image.mode != 'RGB':
                image = image.convert('RGB')
            processed_image = self.image_processor(image)
        else:
            print(f"Error processing image at index {idx}")
            # Create an empty tensor with the right dimensions as a fallback
            processed_image = torch.zeros(
                3, cfg.VLMConfig.vit_img_size, cfg.VLMConfig.vit_img_size)

        # Process the text field (also usually a list)
        text_data = item['texts']
        if isinstance(text_data, list) and len(text_data) > 0:
            text = text_data[0]
        else:
            text = text_data

        question = text['user']
        # Add the EOS token to the answer to train the model to predict it,
        # enabling correct stopping during generation
        answer = text['assistant'] + self.tokenizer.eos_token
        formatted_text = f"Question: {question} Answer:"

        return {
            "image": processed_image,
            "text_data": formatted_text,
            "answer": answer
        }
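
# Example usage (a sketch; the dataset name and image_processor below are
# assumptions, not pinned by this file). VQADataset expects items shaped like
# The Cauldron entries:
#   {"images": [PIL.Image, ...], "texts": [{"user": str, "assistant": str}, ...]}
#
#   from datasets import load_dataset
#   train_ds = load_dataset("HuggingFaceM4/the_cauldron", "vqav2", split="train")
#   vqa = VQADataset(train_ds, tokenizer, image_processor)
#   sample = vqa[0]  # {"image": Tensor[3, H, W], "text_data": str, "answer": str}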


class MMStarDataset(Dataset):  # https://huggingface.co/datasets/Lin-Chen/MMStar
    def __init__(self, dataset, tokenizer, image_processor):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']

        # Now process the image
        if isinstance(image, Image.Image):
            if image.mode != 'RGB':
                image = image.convert('RGB')
            processed_image = self.image_processor(image)
        else:
            print(f"Error processing image at index {idx}")
            # Create an empty tensor with the right dimensions as a fallback
            processed_image = torch.zeros(
                3, cfg.VLMConfig.vit_img_size, cfg.VLMConfig.vit_img_size)

        question = item['question']
        # Add the EOS token to the answer to train the model to predict it,
        # enabling correct stopping during generation
        answer = item['answer'] + self.tokenizer.eos_token
        formatted_text = f"Question: {question} \nAnswer only with the letter! \nAnswer:"

        return {
            "image": processed_image,
            "text_data": formatted_text,
            "answer": answer
        }
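

# Evaluation usage sketch (the split name is an assumption):
#   from datasets import load_dataset
#   mmstar = MMStarDataset(load_dataset("Lin-Chen/MMStar", split="val"),
#                          tokenizer, image_processor)


# A minimal collate sketch for pairing either dataset with a DataLoader
# (illustrative only; `vqa_collate_fn` and `max_length` are not part of this
# file's API). It stacks the image tensors and tokenizes prompt + answer into
# padded batches, assuming a Hugging Face tokenizer with a pad token set
# (e.g. tokenizer.pad_token = tokenizer.eos_token).
def vqa_collate_fn(batch, tokenizer, max_length=128):
    images = torch.stack([b["image"] for b in batch])          # [B, 3, H, W]
    sequences = [b["text_data"] + b["answer"] for b in batch]  # prompt + answer
    encoded = tokenizer(
        sequences,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return {
        "image": images,
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }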