diff --git a/minigpt4/configs/datasets/textvqa/default.yaml b/minigpt4/configs/datasets/textvqa/default.yaml
new file mode 100755
index 0000000..13c4516
--- /dev/null
+++ b/minigpt4/configs/datasets/textvqa/default.yaml
@@ -0,0 +1,6 @@
+datasets:
+  textvqa:
+    data_type: images
+    build_info:
+      image_path: ./data/TextVQA_tiny/images
+      ann_path: ./data/TextVQA_tiny/train.json
diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py
index 403de6e..7fc49c4 100644
--- a/minigpt4/datasets/builders/image_text_pair_builder.py
+++ b/minigpt4/datasets/builders/image_text_pair_builder.py
@@ -19,6 +19,7 @@ from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
 from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
 from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
 from minigpt4.datasets.datasets.mvtec_dataset import MVTecDataset
+from minigpt4.datasets.datasets.vqa_datasets import TextVQADataset
 
 @registry.register_builder("multitask_conversation")
 class MultitaskConversationBuilder(BaseDatasetBuilder):
@@ -418,6 +419,28 @@ class MVTECADBuilder(BaseDatasetBuilder):
 
         return datasets
 
+@registry.register_builder("textvqa")
+class TextVQABuilder(BaseDatasetBuilder):
+    train_dataset_cls = TextVQADataset
+    DATASET_CONFIG_DICT = {
+        "default": "configs/datasets/textvqa/default.yaml",
+    }
+    def build_datasets(self):
+        logging.info("Building datasets...")
+        self.build_processors()
+        build_info = self.config.build_info
+        datasets = dict()
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets['train'] = dataset_cls(
+            vis_processor=self.vis_processors["train"],
+            text_processor=self.text_processors["train"],
+            vis_root=build_info.image_path,
+            ann_path=build_info.ann_path,
+        )
+        return datasets
+
 
 class DocumentVQABuilder(BaseDatasetBuilder):
     def _download_ann(self):
diff --git a/minigpt4/datasets/datasets/vqa_datasets.py b/minigpt4/datasets/datasets/vqa_datasets.py
index a8df3cc..27966d8 100755
--- a/minigpt4/datasets/datasets/vqa_datasets.py
+++ b/minigpt4/datasets/datasets/vqa_datasets.py
@@ -8,6 +8,8 @@ import torch
 from PIL import Image
 import os
+import random
+import json
 
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
 
 
@@ -22,6 +24,39 @@ class VQAEvalDataset(BaseDataset):
         super().__init__(vis_processor, text_processor, vis_root, ann_paths)
 
 
+class TextVQADataset(torch.utils.data.Dataset):
+    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+        self.vis_root = vis_root
+        self.vis_processor = vis_processor
+        self.text_processor = text_processor
+
+        self.instruction_pool = [
+            "[vqa] {}",
+            "[vqa] Based on the image, respond to this question with a short answer: {}"
+        ]
+        with open(ann_path, 'r') as f:
+            self.data = json.load(f)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image_path = os.path.join(self.vis_root, f"{sample['image_id']}.jpg")
+        image = Image.open(image_path).convert("RGB")
+        image = self.vis_processor(image)
+        question = self.text_processor(sample["question"])
+        answer = self.text_processor(sample["answer"])
+        instruction = random.choice(self.instruction_pool).format(question)
+        instruction = " {} ".format(instruction)
+        return {
+            "image": image,
+            "instruction_input": instruction,
+            "answer": answer,
+            "image_id": sample['image_id']
+        }
+
+
 class OKVQAEvalData(torch.utils.data.Dataset):
     def __init__(self, loaded_data, vis_processor, root_path):
         self.loaded_data = loaded_data
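
To actually train on the new dataset, a finetune config also needs a `textvqa` entry under its `datasets:` section so the registered builder gets picked up. The stanza below is a minimal sketch assuming the layout of MiniGPT-4's existing train configs (e.g. `train_configs/minigptv2_finetune.yaml`); it is not part of this patch, and the processor names, image size, batch size, and `sample_ratio` are illustrative placeholders.

```yaml
# Hypothetical train-config entry enabling the "textvqa" builder registered above.
# Only the dataset key "textvqa" comes from the patch; all values are illustrative.
datasets:
  textvqa:
    batch_size: 2
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 448
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 10
```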