Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
Synced 2025-04-05 02:20:47 +00:00

add textvqa for training

Commit bb2a833fc1 (parent 006a7b863b)
minigpt4/configs/datasets/textvqa/default.yaml (new executable file, +6 lines)

@@ -0,0 +1,6 @@
+datasets:
+  textvqa:
+    data_type: images
+    build_info:
+      image_path: ./data/TextVQA_tiny/images
+      ann_path: ./data/TextVQA_tiny/train.json
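The annotation file itself is not part of the commit. From the keys the new TextVQADataset.__getitem__ reads further down (image_id, question, answer), train.json is presumably a flat JSON list of records; a hypothetical entry, with invented placeholder values, would look like:

# Hypothetical shape of ./data/TextVQA_tiny/train.json, inferred from the keys
# TextVQADataset.__getitem__ reads; the values below are invented placeholders,
# not real TextVQA records.
import json

example_annotations = [
    {
        "image_id": "0054c91bcbf888dd",  # image file: <image_path>/0054c91bcbf888dd.jpg
        "question": "what does the sign say?",
        "answer": "stop",
    },
]
print(json.dumps(example_annotations, indent=2))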
@@ -19,6 +19,7 @@ from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset
 from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset
 from minigpt4.datasets.datasets.coco_caption import COCOCapDataset
 from minigpt4.datasets.datasets.mvtec_dataset import MVTecDataset
+from minigpt4.datasets.datasets.vqa_datasets import TextVQADataset

 @registry.register_builder("multitask_conversation")
 class MultitaskConversationBuilder(BaseDatasetBuilder):

@@ -418,6 +419,28 @@ class MVTECADBuilder(BaseDatasetBuilder):
         return datasets


+@registry.register_builder("textvqa")
+class TextVQABuilder(BaseDatasetBuilder):
+    train_dataset_cls = TextVQADataset
+    DATASET_CONFIG_DICT = {
+        "default": "configs/datasets/textvqa/default.yaml",
+    }
+
+    def build_datasets(self):
+        logging.info("Building datasets...")
+        self.build_processors()
+        build_info = self.config.build_info
+        datasets = dict()
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets['train'] = dataset_cls(
+            vis_processor=self.vis_processors["train"],
+            text_processor=self.text_processors["train"],
+            vis_root=build_info.image_path,
+            ann_path=build_info.ann_path,
+        )
+
+        return datasets
+
+
 class DocumentVQABuilder(BaseDatasetBuilder):
     def _download_ann(self):
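For orientation only (this snippet is not in the diff): once registered under "textvqa", the builder should be reachable through the registry like the other dataset builders. A minimal sketch, assuming the LAVIS-style registry.get_builder_class accessor this codebase inherits, that a builder constructed without a config falls back to its "default" DATASET_CONFIG_DICT entry, and that the TextVQA_tiny files exist at the configured paths:

# Sketch; assumes a LAVIS-style registry API and default-config fallback.
from minigpt4.common.registry import registry

builder_cls = registry.get_builder_class("textvqa")  # resolves to TextVQABuilder
builder = builder_cls()             # no cfg: loads configs/datasets/textvqa/default.yaml
datasets = builder.build_datasets()

sample = datasets["train"][0]
print(sample["instruction_input"])  # "<Img><ImageHere></Img> [vqa] ..."
print(sample["answer"])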
minigpt4/datasets/datasets/vqa_datasets.py

@@ -8,6 +8,8 @@
 import torch
 from PIL import Image
 import os
+import random
+import json

 from minigpt4.datasets.datasets.base_dataset import BaseDataset

@@ -22,6 +24,39 @@ class VQAEvalDataset(BaseDataset):
         super().__init__(vis_processor, text_processor, vis_root, ann_paths)


+class TextVQADataset(torch.utils.data.Dataset):
+    def __init__(self, vis_processor, text_processor, vis_root, ann_path):
+        self.vis_root = vis_root
+        self.vis_processor = vis_processor
+        self.text_processor = text_processor
+
+        self.instruction_pool = [
+            "[vqa] {}",
+            "[vqa] Based on the image, respond to this question with a short answer: {}",
+        ]
+        with open(ann_path, 'r') as f:
+            self.data = json.load(f)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image_path = os.path.join(self.vis_root, f"{sample['image_id']}.jpg")
+        image = Image.open(image_path).convert("RGB")
+        image = self.vis_processor(image)
+        question = self.text_processor(sample["question"])
+        answer = self.text_processor(sample["answer"])
+        instruction = random.choice(self.instruction_pool).format(question)
+        instruction = "<Img><ImageHere></Img> {} ".format(instruction)
+        return {
+            "image": image,
+            "instruction_input": instruction,
+            "answer": answer,
+            "image_id": sample['image_id']
+        }
+
+
 class OKVQAEvalData(torch.utils.data.Dataset):
     def __init__(self, loaded_data, vis_processor, root_path):
         self.loaded_data = loaded_data
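Not part of the commit: because the dataset class takes plain callables for its processors, it can be sanity-checked in isolation. A minimal sketch, assuming the TextVQA_tiny images and train.json are present at the configured paths; the _Identity pass-through processor is invented here and stands in for the real image/text processors:

from minigpt4.datasets.datasets.vqa_datasets import TextVQADataset

class _Identity:
    # stand-in processor: returns its input unchanged (hypothetical helper)
    def __call__(self, x):
        return x

ds = TextVQADataset(
    vis_processor=_Identity(),
    text_processor=_Identity(),
    vis_root="./data/TextVQA_tiny/images",
    ann_path="./data/TextVQA_tiny/train.json",
)
print(len(ds))
print(ds[0]["instruction_input"])  # e.g. "<Img><ImageHere></Img> [vqa] <question> "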