"""
|
|
Copyright (c) 2022, salesforce.com, inc.
|
|
All rights reserved.
|
|
SPDX-License-Identifier: BSD-3-Clause
|
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
"""
|
|
|
|
import json
from typing import Iterable

from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data.dataloader import default_collate


class BaseDataset(Dataset):
    def __init__(
        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths (list of string): paths to the annotation files
        """
        self.vis_root = vis_root

        # Merge the annotation entries from every annotation file into a single list.
        self.annotation = []
        for ann_path in ann_paths:
            self.annotation.extend(json.load(open(ann_path, "r"))["annotations"])

        self.vis_processor = vis_processor
        self.text_processor = text_processor

        self._add_instance_ids()

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return default_collate(samples)

    def set_processors(self, vis_processor, text_processor):
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def _add_instance_ids(self, key="instance_id"):
        # Tag each annotation with a unique, stringified index.
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)
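
# Illustrative sketch (not part of the original file): concrete datasets are expected to
# subclass BaseDataset and implement __getitem__, typically pairing an image from
# vis_root with a text field from the annotation. The names below (CaptionDataset,
# load_image) are hypothetical placeholders.
#
#   class CaptionDataset(BaseDataset):
#       def __getitem__(self, index):
#           ann = self.annotation[index]
#           image = load_image(os.path.join(self.vis_root, ann["image"]))
#           return {
#               "image": self.vis_processor(image),
#               "text_input": self.text_processor(ann["caption"]),
#               "instance_id": ann["instance_id"],
#           }
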
class ConcatDataset(ConcatDataset):
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def collater(self, samples):
        # TODO For now only supports datasets with the same underlying collater implementation.

        # Collect every key that appears in any sample ...
        all_keys = set()
        for s in samples:
            all_keys.update(s)

        # ... then keep only the keys shared by all samples.
        shared_keys = all_keys
        for s in samples:
            shared_keys = shared_keys & set(s.keys())

        samples_shared_keys = []
        for s in samples:
            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})

        # Delegate batching of the shared keys to the first dataset's collater.
        return self.datasets[0].collater(samples_shared_keys)
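

if __name__ == "__main__":
    # Minimal, illustrative sketch (not part of the original module): shows how
    # ConcatDataset.collater keeps only the keys shared by every sample before
    # delegating to the first dataset's collater (default_collate here).
    ds = BaseDataset(ann_paths=[])  # no annotation files; only its collater is used
    concat = ConcatDataset([ds])

    samples = [
        {"text_input": 0, "label": 1, "extra": "only in the first sample"},
        {"text_input": 1, "label": 0},
    ]
    batch = concat.collater(samples)
    print(sorted(batch.keys()))  # ['label', 'text_input'] -- "extra" is dropped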