"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import random
from collections import OrderedDict

import torch
from PIL import Image

from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset


# class textVQADataset(VQADataset):
#     def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
#         super().__init__(vis_processor, text_processor, vis_root, ann_paths)
#
#     def collater(self, samples):
#         image_list, question_list, answer_list, weight_list = [], [], [], []
#         num_answers = []
#         for sample in samples:
#             image_list.append(sample["image"])
#             question_list.append(sample["text_input"])
#             weight_list.extend(sample["weights"])
#             answers = sample["answers"]
#             answer_list.extend(answers)
#             num_answers.append(len(answers))
#         return {
#             "image": torch.stack(image_list, dim=0),
#             "text_input": question_list,
#             "answer": answer_list,
#             "weight": torch.Tensor(weight_list),
#             "n_answers": torch.LongTensor(num_answers),
#         }


class textVQAEvalDataset(VQADataset):
    """TextVQA validation split loaded via the Hugging Face ``datasets`` library.

    Each item contains the processed image, the processed question, the list
    of reference answers, and an instruction prompt built by filling a
    randomly chosen template from ``instruction_pool`` with the question.
    """

    # Fallback prompt templates; ``{}`` is replaced by the processed question.
    # NOTE(review): follows the MiniGPT-v2 "[vqa]" task-tag style — confirm
    # these match the prompts the evaluated model expects.
    DEFAULT_INSTRUCTION_POOL = (
        "[vqa] {}",
        "[vqa] Based on the image, respond to this question with a short answer: {}",
    )

    def __init__(
        self,
        vis_processor,
        text_processor,
        vis_root=None,
        ann_paths=None,
        instruction_pool=None,
    ):
        """Load the TextVQA validation split.

        Args:
            vis_processor: callable applied to each RGB PIL image.
            text_processor: callable applied to each question string.
            vis_root: unused (images come embedded in the HF dataset); kept
                for signature compatibility with the other VQA datasets.
            ann_paths: unused, for the same reason as ``vis_root``.
            instruction_pool: optional sequence of ``str.format`` templates
                with one ``{}`` slot for the question. Defaults to
                ``DEFAULT_INSTRUCTION_POOL``.
        """
        # The base initializer is deliberately not called: annotations come
        # from the HF hub rather than local ``ann_paths`` files, so the
        # attributes the methods below rely on are set explicitly here.
        from datasets import load_dataset

        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if instruction_pool is None:
            instruction_pool = self.DEFAULT_INSTRUCTION_POOL
        self.instruction_pool = list(instruction_pool)
        self.annotation = load_dataset("textvqa", split="validation")

    def __len__(self):
        """Return the number of validation questions."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Return one processed evaluation sample as a dict."""
        ann = self.annotation[index]

        image = self.vis_processor(ann["image"].convert("RGB"))
        question = self.text_processor(ann["question"])

        instruction = random.choice(self.instruction_pool).format(question)
        instruction = " {} ".format(instruction)

        return {
            "image": image,
            "text_input": question,
            "answer": ann["answers"],
            "instruction_input": instruction,
            "question_id": ann["question_id"],
            # Raw HF rows presumably carry no "instance_id" (the base dataset
            # normally injects it), so fall back to the stringified index.
            "instance_id": ann.get("instance_id", str(index)),
        }


if __name__ == "__main__":
    # Manual smoke test, guarded so that importing this module neither fails
    # on undefined processors nor triggers a dataset download.
    dataset = textVQAEvalDataset(
        vis_processor=lambda image: image,
        text_processor=lambda text: text,
    )
    # Identity collate: the identity vis_processor yields PIL images, which the
    # default collate function cannot batch.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, collate_fn=lambda batch: batch
    )
    print(next(iter(dataloader))[0]["instruction_input"])