diff --git a/minigpt4/configs/datasets/textvqa/default.yaml b/minigpt4/configs/datasets/textvqa/default.yaml
new file mode 100755
index 0000000..13c4516
--- /dev/null
+++ b/minigpt4/configs/datasets/textvqa/default.yaml
@@ -0,0 +1,6 @@
+datasets:
+ textvqa:
+ data_type: images
+ build_info:
+ image_path: ./data/TextVQA_tiny/images
+ ann_path: ./data/TextVQA_tiny/train.json
diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py
index 403de6e..7fc49c4 100644
--- a/minigpt4/datasets/builders/image_text_pair_builder.py
+++ b/minigpt4/datasets/builders/image_text_pair_builder.py
@@ -19,6 +19,7 @@ from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
from minigpt4.datasets.datasets.mvtec_dataset import MVTecDataset
+from minigpt4.datasets.datasets.vqa_datasets import TextVQADataset
@registry.register_builder("multitask_conversation")
class MultitaskConversationBuilder(BaseDatasetBuilder):
@@ -418,6 +419,28 @@ class MVTECADBuilder(BaseDatasetBuilder):
return datasets
+@registry.register_builder("textvqa")
+class TextVQABuilder(BaseDatasetBuilder):
+ train_dataset_cls = TextVQADataset
+ DATASET_CONFIG_DICT = {
+ "default": "configs/datasets/textvqa/default.yaml",
+ }
+ def build_datasets(self):
+ logging.info("Building datasets...")
+ self.build_processors()
+ build_info = self.config.build_info
+ datasets = dict()
+
+ # create datasets
+ dataset_cls = self.train_dataset_cls
+ datasets['train'] = dataset_cls(
+ vis_processor=self.vis_processors["train"],
+ text_processor=self.text_processors["train"],
+ vis_root=build_info.image_path,
+ ann_path=build_info.ann_path,
+ )
+ return datasets
+
class DocumentVQABuilder(BaseDatasetBuilder):
def _download_ann(self):
diff --git a/minigpt4/datasets/datasets/vqa_datasets.py b/minigpt4/datasets/datasets/vqa_datasets.py
index a8df3cc..27966d8 100755
--- a/minigpt4/datasets/datasets/vqa_datasets.py
+++ b/minigpt4/datasets/datasets/vqa_datasets.py
@@ -8,6 +8,8 @@
import torch
from PIL import Image
import os
+import random
+import json
from minigpt4.datasets.datasets.base_dataset import BaseDataset
@@ -22,6 +24,39 @@ class VQAEvalDataset(BaseDataset):
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+class TextVQADataset(torch.utils.data.Dataset):
+ def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+ self.vis_root = vis_root
+ self.vis_processor = vis_processor
+ self.text_processor = text_processor
+
+ self.instruction_pool =[
+ "[vqa] {}",
+ "[vqa] Based on the image, respond to this question with a short answer: {}"
+ ]
+ with open(ann_path, 'r') as f:
+ self.data = json.load(f)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ sample = self.data[index]
+ image_path = os.path.join(self.vis_root, f"{sample['image_id']}.jpg")
+ image = Image.open(image_path).convert("RGB")
+ image = self.vis_processor(image)
+ question = self.text_processor(sample["question"])
+ answer = self.text_processor(sample["answer"])
+ instruction = random.choice(self.instruction_pool).format(question)
+        instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+ return {
+ "image": image,
+ "instruction_input": instruction,
+ "answer": answer,
+ "image_id": sample['image_id']
+ }
+
+
class OKVQAEvalData(torch.utils.data.Dataset):
def __init__(self, loaded_data, vis_processor, root_path):
self.loaded_data = loaded_data