Add TextVQA dataset support for training

This commit is contained in:
ThuanNaN 2025-01-20 07:23:40 +07:00
parent 006a7b863b
commit bb2a833fc1
3 changed files with 64 additions and 0 deletions

View File

@ -0,0 +1,6 @@
# Dataset registry entry for the TextVQA tiny training split.
datasets:
  textvqa:
    data_type: images  # samples are image files paired with Q/A annotations
    build_info:
      # NOTE(review): paths appear relative to the project root — confirm the
      # expected working directory before training.
      image_path: ./data/TextVQA_tiny/images
      ann_path: ./data/TextVQA_tiny/train.json

View File

@ -19,6 +19,7 @@ from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
from minigpt4.datasets.datasets.coco_caption import COCOCapDataset from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
from minigpt4.datasets.datasets.mvtec_dataset import MVTecDataset from minigpt4.datasets.datasets.mvtec_dataset import MVTecDataset
from minigpt4.datasets.datasets.vqa_datasets import TextVQADataset
@registry.register_builder("multitask_conversation") @registry.register_builder("multitask_conversation")
class MultitaskConversationBuilder(BaseDatasetBuilder): class MultitaskConversationBuilder(BaseDatasetBuilder):
@ -418,6 +419,28 @@ class MVTECADBuilder(BaseDatasetBuilder):
return datasets return datasets
@registry.register_builder("textvqa")
class TextVQABuilder(BaseDatasetBuilder):
    """Dataset builder for the TextVQA training split.

    Reads ``image_path`` and ``ann_path`` from the registered YAML config
    (``configs/datasets/textvqa/default.yaml``) and produces a single
    ``TextVQADataset`` under the ``'train'`` key.
    """

    train_dataset_cls = TextVQADataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/textvqa/default.yaml",
    }

    def build_datasets(self):
        """Build processors, then return ``{'train': TextVQADataset}``.

        Returns:
            dict: mapping split name to the constructed dataset instance.
        """
        logging.info("Building datasets...")
        # Processors must exist before the dataset is instantiated, since the
        # dataset applies them per-sample in __getitem__.
        self.build_processors()

        info = self.config.build_info
        train_cls = self.train_dataset_cls
        return {
            "train": train_cls(
                vis_processor=self.vis_processors["train"],
                text_processor=self.text_processors["train"],
                vis_root=info.image_path,
                ann_path=info.ann_path,
            )
        }
class DocumentVQABuilder(BaseDatasetBuilder): class DocumentVQABuilder(BaseDatasetBuilder):
def _download_ann(self): def _download_ann(self):

View File

@ -8,6 +8,8 @@
import torch import torch
from PIL import Image from PIL import Image
import os import os
import random
import json
from minigpt4.datasets.datasets.base_dataset import BaseDataset from minigpt4.datasets.datasets.base_dataset import BaseDataset
@ -22,6 +24,39 @@ class VQAEvalDataset(BaseDataset):
super().__init__(vis_processor, text_processor, vis_root, ann_paths) super().__init__(vis_processor, text_processor, vis_root, ann_paths)
class TextVQADataset(torch.utils.data.Dataset):
    """Map-style training dataset for TextVQA.

    Loads a JSON annotation list (one dict per sample with ``image_id``,
    ``question`` and ``answer`` keys) and serves processed image/text pairs
    formatted with a randomly chosen VQA instruction template.
    """

    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
        """
        Args:
            vis_processor: callable applied to the loaded PIL image.
            text_processor: callable applied to question and answer strings.
            vis_root: directory containing ``<image_id>.jpg`` files.
            ann_path: path to the JSON annotation file.
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        # One template is sampled per item to diversify instruction phrasing.
        self.instruction_pool = [
            "[vqa] {}",
            "[vqa] Based on the image, respond to this question with a short answer: {}",
        ]
        with open(ann_path, 'r') as handle:
            self.data = json.load(handle)

    def __len__(self):
        """Number of annotated samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return one processed sample as a dict for the training collator."""
        record = self.data[index]

        # Images are stored as <image_id>.jpg under vis_root — TODO confirm
        # every annotation has a matching .jpg file.
        img_file = os.path.join(self.vis_root, "{}.jpg".format(record['image_id']))
        img = Image.open(img_file).convert("RGB")
        img = self.vis_processor(img)

        question_text = self.text_processor(record["question"])
        answer_text = self.text_processor(record["answer"])

        prompt = random.choice(self.instruction_pool).format(question_text)
        prompt = "<Img><ImageHere></Img> {} ".format(prompt)

        return {
            "image": img,
            "instruction_input": prompt,
            "answer": answer_text,
            "image_id": record['image_id'],
        }
class OKVQAEvalData(torch.utils.data.Dataset): class OKVQAEvalData(torch.utils.data.Dataset):
def __init__(self, loaded_data, vis_processor, root_path): def __init__(self, loaded_data, vis_processor, root_path):
self.loaded_data = loaded_data self.loaded_data = loaded_data