rename some components for further extension

This commit is contained in:
unknown 2023-05-22 18:48:43 +08:00
parent 49f6b84880
commit d2ba3c48b6
19 changed files with 292 additions and 101 deletions

View File

@ -32,10 +32,10 @@ class Registry:
"""
def wrap(builder_cls):
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
assert issubclass(
builder_cls, BaseDatasetBuilder
builder_cls, ImageBaseDatasetBuilder
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
builder_cls
)

View File

@ -5,7 +5,7 @@ model:
freeze_imagebind: True
# Q-Former
freeze_qformer: True
freeze_qformer: False
num_query_token: 32
# Vicuna

View File

@ -5,18 +5,18 @@
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
from minigpt4.datasets.builders.image_base_dataset_builder import load_dataset_config
from minigpt4.datasets.builders.image_text_pair_builder import (
CCSBUBuilder,
LaionBuilder,
CCSBUAlignBuilder
CCSBUBuilderImage,
LaionBuilderImage,
CCSBUAlignBuilderImage
)
from minigpt4.common.registry import registry
__all__ = [
"CCSBUBuilder",
"LaionBuilder",
"CCSBUAlignBuilder"
"CCSBUBuilderImage",
"LaionBuilderImage",
"CCSBUAlignBuilderImage"
]

View File

@ -0,0 +1,99 @@
import logging
import os
import shutil
import warnings
from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url
import minigpt4.common.utils as utils
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from minigpt4.common.registry import registry
from minigpt4.datasets.builders import load_dataset_config
from minigpt4.processors.base_processor import BaseProcessor
class AudioBaseDatasetBuilder:
train_dataset_cls, eval_dataset_cls = None, None
def __init__(self, cfg=None):
super().__init__()
if cfg is None:
# help to create datasets from default config.
self.config = load_dataset_config(self.default_config_path())
elif isinstance(cfg, str):
self.config = load_dataset_config(cfg)
else:
# when called from task.build_dataset()
self.config = cfg
self.data_type = self.config.data_type
@classmethod
def default_config_path(cls, type="default"):
return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
def _download_data(self):
self._download_ann()
self._download_aud()
def _download_ann(self):
"""
Download annotation files if necessary.
All the audio-language datasets should have annotations of unified format.
storage_path can be:
(1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
(2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
Local annotation paths should be relative.
"""
anns = self.config.build_info.annotations
splits = anns.keys()
cache_root = registry.get_path("cache_root")
for split in splits:
info = anns[split]
urls, storage_paths = info.get("url", None), info.storage
if isinstance(urls, str):
urls = [urls]
if isinstance(storage_paths, str):
storage_paths = [storage_paths]
assert len(urls) == len(storage_paths)
for url_or_filename, storage_path in zip(urls, storage_paths):
# if storage_path is relative, make it full by prefixing with cache_root.
if not os.path.isabs(storage_path):
storage_path = os.path.join(cache_root, storage_path)
dirname = os.path.dirname(storage_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
if os.path.isfile(url_or_filename):
src, dst = url_or_filename, storage_path
if not os.path.exists(dst):
shutil.copyfile(src=src, dst=dst)
else:
logging.info("Using existing file {}.".format(dst))
else:
if os.path.isdir(storage_path):
# if only dirname is provided, suffix with basename of URL.
raise ValueError(
"Expecting storage_path to be a file path, got directory {}".format(
storage_path
)
)
else:
filename = os.path.basename(storage_path)
download_url(url=url_or_filename, root=dirname, filename=filename)

View File

@ -21,8 +21,7 @@ from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor
class BaseDatasetBuilder:
class ImageBaseDatasetBuilder:
train_dataset_cls, eval_dataset_cls = None, None
def __init__(self, cfg=None):

View File

@ -3,13 +3,13 @@ import logging
import warnings
from minigpt4.common.registry import registry
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from minigpt4.datasets.datasets.laion_dataset import LaionDataset
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
from minigpt4.datasets.datasets.image_caption.laion_dataset import LaionDataset
from minigpt4.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDatasetImageImageCaptionDataset
@registry.register_builder("cc_sbu")
class CCSBUBuilder(BaseDatasetBuilder):
class CCSBUBuilderImage(ImageBaseDatasetBuilder):
train_dataset_cls = CCSBUDataset
DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
@ -41,7 +41,7 @@ class CCSBUBuilder(BaseDatasetBuilder):
@registry.register_builder("laion")
class LaionBuilder(BaseDatasetBuilder):
class LaionBuilderImage(ImageBaseDatasetBuilder):
train_dataset_cls = LaionDataset
DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
@ -73,8 +73,8 @@ class LaionBuilder(BaseDatasetBuilder):
@registry.register_builder("cc_sbu_align")
class CCSBUAlignBuilder(BaseDatasetBuilder):
train_dataset_cls = CCSBUAlignDataset
class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder):
train_dataset_cls = CCSBUAlignDatasetImageImageCaptionDataset
DATASET_CONFIG_DICT = {
"default": "configs/datasets/cc_sbu/align.yaml",

View File

@ -5,32 +5,44 @@
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile
import random
from typing import List
from tqdm import tqdm
from typing import List, Iterable
import decord
from decord import VideoReader
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import IterableDataset, Dataset, ConcatDataset
from minigpt4.common.registry import registry
from minigpt4.datasets.datasets.base_dataset import ConcatDataset
decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")
class ChainDataset(wds.DataPipeline):
class WrappedConcatDataset(ConcatDataset):
def __init__(self, datasets: Iterable[Dataset]) -> None:
super().__init__(datasets)
def collater(self, samples):
# TODO For now only supports datasets with same underlying collater implementations
all_keys = set()
for s in samples:
all_keys.update(s)
shared_keys = all_keys
for s in samples:
shared_keys = shared_keys & set(s.keys())
samples_shared_keys = []
for s in samples:
samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
return self.datasets[0].collater(samples_shared_keys)
class WrappedChainDataset(wds.DataPipeline):
r"""Dataset for chaining multiple :class:`DataPipeline` s.
This class is useful to assemble different existing dataset streams. The
@ -40,6 +52,7 @@ class ChainDataset(wds.DataPipeline):
Args:
datasets (iterable of IterableDataset): datasets to be chained together
"""
def __init__(self, datasets: List[wds.DataPipeline]) -> None:
super().__init__()
self.datasets = datasets
@ -149,7 +162,7 @@ def concat_datasets(datasets):
for split_name in datasets:
if split_name != "train":
assert (
len(datasets[split_name]) == 1
len(datasets[split_name]) == 1
), "Do not support multiple {} datasets.".format(split_name)
datasets[split_name] = datasets[split_name][0]
else:
@ -173,7 +186,7 @@ def concat_datasets(datasets):
# concatenate map-style datasets and iterable-style datasets separately
if len(iterable_datasets) > 1:
chained_datasets = (
ChainDataset(iterable_datasets)
WrappedChainDataset(iterable_datasets)
)
elif len(iterable_datasets) == 1:
chained_datasets = iterable_datasets[0]
@ -181,7 +194,7 @@ def concat_datasets(datasets):
chained_datasets = None
concat_datasets = (
ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
WrappedConcatDataset(map_datasets) if len(map_datasets) > 0 else None
)
train_datasets = concat_datasets, chained_datasets
@ -193,4 +206,3 @@ def concat_datasets(datasets):
datasets[split_name] = train_datasets
return datasets

View File

@ -0,0 +1,26 @@
import json
from torch.utils.data import Dataset, default_collate
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
class GenericAudioDataset(BaseDataset):
def __init__(self, vision_processor, text_processor, location):
super().__init__(x_processor=vision_processor, text_processor=text_processor)
self.inner_dataset = wds.DataPipeline(
wds.ResampledShards(location),
wds.tarfile_to_samples(handler=wds.warn_and_continue),
wds.shuffle(1000, handler=wds.warn_and_continue),
wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
wds.map(self.to_dict, handler=wds.warn_and_continue),
)
def to_dict(self, sample):
return {
"image": sample[0],
"text_input": self.text_processor(sample[1]["caption"]),
}

View File

@ -8,25 +8,25 @@
import json
from typing import Iterable
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
class BaseDataset(Dataset):
def __init__(
self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
self, x_processor=None, text_processor=None, x_root=None, ann_paths=[]
):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
x_root (string): Root directory of data in modality X (e.g. coco/images/, etc.)
ann_root (string): directory to store the annotation file
"""
self.vis_root = vis_root
self.x_root = x_root
self.annotation = []
for ann_path in ann_paths:
self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
self.vis_processor = vis_processor
self.x_processor = x_processor
self.text_processor = text_processor
self._add_instance_ids()
@ -37,32 +37,11 @@ class BaseDataset(Dataset):
def collater(self, samples):
return default_collate(samples)
def set_processors(self, vis_processor, text_processor):
self.vis_processor = vis_processor
def set_processors(self, x_processor, text_processor):
self.x_processor = x_processor
self.text_processor = text_processor
def _add_instance_ids(self, key="instance_id"):
for idx, ann in enumerate(self.annotation):
ann[key] = str(idx)
class ConcatDataset(ConcatDataset):
def __init__(self, datasets: Iterable[Dataset]) -> None:
super().__init__(datasets)
def collater(self, samples):
# TODO For now only supports datasets with same underlying collater implementations
all_keys = set()
for s in samples:
all_keys.update(s)
shared_keys = all_keys
for s in samples:
shared_keys = shared_keys & set(s.keys())
samples_shared_keys = []
for s in samples:
samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
return self.datasets[0].collater(samples_shared_keys)

View File

@ -2,12 +2,12 @@ import os
from PIL import Image
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
from minigpt4.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset
class CCSBUDataset(BaseDataset):
def __init__(self, vis_processor, text_processor, location):
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
def __init__(self, vision_processor, text_processor, location):
super().__init__(x_processor=vision_processor, text_processor=text_processor)
self.inner_dataset = wds.DataPipeline(
wds.ResampledShards(location),
@ -15,7 +15,7 @@ class CCSBUDataset(BaseDataset):
wds.shuffle(1000, handler=wds.warn_and_continue),
wds.decode("pilrgb", handler=wds.warn_and_continue),
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
wds.map(self.to_dict, handler=wds.warn_and_continue),
)
@ -26,7 +26,7 @@ class CCSBUDataset(BaseDataset):
}
class CCSBUAlignDataset(CaptionDataset):
class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
def __getitem__(self, index):
@ -34,10 +34,10 @@ class CCSBUAlignDataset(CaptionDataset):
ann = self.annotation[index]
img_file = '{}.jpg'.format(ann["image_id"])
image_path = os.path.join(self.vis_root, img_file)
image_path = os.path.join(self.x_root, img_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
image = self.x_processor(image)
caption = ann["caption"]
return {

View File

@ -6,32 +6,20 @@
"""
import os
from collections import OrderedDict
from minigpt4.datasets.datasets.base_dataset import BaseDataset
from PIL import Image
class __DisplMixin:
def displ_item(self, index):
sample, ann = self.__getitem__(index), self.annotation[index]
return OrderedDict(
{
"file": ann["image"],
"caption": ann["caption"],
"image": sample["image"],
}
)
from minigpt4.datasets.datasets.mixins.mixins import __ImageDisplMixin
class CaptionDataset(BaseDataset, __DisplMixin):
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
def __init__(self, vision_processor, text_processor, x_root, ann_paths):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
"""
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
super().__init__(vision_processor, text_processor, x_root, ann_paths)
self.img_ids = {}
n = 0
@ -47,10 +35,10 @@ class CaptionDataset(BaseDataset, __DisplMixin):
ann = self.annotation[index]
img_file = '{:0>12}.jpg'.format(ann["image_id"])
image_path = os.path.join(self.vis_root, img_file)
image_path = os.path.join(self.x_root, img_file)
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
image = self.x_processor(image)
caption = self.text_processor(ann["caption"])
return {
@ -60,23 +48,23 @@ class CaptionDataset(BaseDataset, __DisplMixin):
}
class CaptionEvalDataset(BaseDataset, __DisplMixin):
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
def __init__(self, vision_processor, text_processor, x_root, ann_paths):
"""
vis_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
"""
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
super().__init__(vision_processor, text_processor, x_root, ann_paths)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.vis_root, ann["image"])
image_path = os.path.join(self.x_root, ann["image"])
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
image = self.x_processor(image)
return {
"image": image,

View File

@ -10,8 +10,8 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset
class LaionDataset(BaseDataset):
def __init__(self, vis_processor, text_processor, location):
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
def __init__(self, vision_processor, text_processor, location):
super().__init__(x_processor=vision_processor, text_processor=text_processor)
self.inner_dataset = wds.DataPipeline(
wds.ResampledShards(location),
@ -19,7 +19,7 @@ class LaionDataset(BaseDataset):
wds.shuffle(1000, handler=wds.warn_and_continue),
wds.decode("pilrgb", handler=wds.warn_and_continue),
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
wds.map(self.to_dict, handler=wds.warn_and_continue),
)

View File

@ -0,0 +1,30 @@
from collections import OrderedDict
class __ImageDisplMixin:
def displ_item(self, index):
sample, ann = self.__getitem__(index), self.annotation[index]
return OrderedDict(
{
"file": ann["image"],
"caption": ann["caption"],
"image": sample["image"],
}
)
class __AudioDisplMixin:
def displ_item(self, index):
sample, ann = self.__getitem__(index), self.annotation[index]
# TODO: Finish the Audio Display Mixin
'''
return OrderedDict(
{
}
)
'''
raise NotImplementedError

View File

@ -145,4 +145,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
mean = cfg.get("mean", None)
std = cfg.get("std", None)
return cls(image_size=image_size, mean=mean, std=std)
return cls(image_size=image_size, mean=mean, std=std)

View File

@ -24,7 +24,7 @@ from minigpt4.common.dist_utils import (
)
from minigpt4.common.registry import registry
from minigpt4.common.utils import is_url
from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, WrappedChainDataset
from minigpt4.datasets.datasets.dataloader_utils import (
IterLoader,
MultiIterLoader,
@ -219,7 +219,7 @@ class RunnerBase:
num_records = sum(
[
len(d)
if not type(d) in [wds.DataPipeline, ChainDataset]
if not type(d) in [wds.DataPipeline, WrappedChainDataset]
else 0
for d in self.datasets[split_name]
]
@ -503,7 +503,7 @@ class RunnerBase:
def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
# create a single dataloader for each split
if isinstance(dataset, ChainDataset) or isinstance(
if isinstance(dataset, WrappedChainDataset) or isinstance(
dataset, wds.DataPipeline
):
# wds.WebdDataset instance are chained together

View File

@ -0,0 +1,57 @@
model:
arch: bind_gpt4
model_type: pretrain_vicuna
freeze_imagebind: True
freeze_qformer: False
datasets:
laion:
vis_processor:
train:
name: "imagebind_vision_train"
image_size: 224
text_processor:
train:
name: "imagebind_caption"
sample_ratio: 115
cc_sbu:
vis_processor:
train:
name: "imagebind_vision_train"
image_size: 224
text_processor:
train:
name: "imagebind_caption"
sample_ratio: 14
run:
task: imagebind_qformer_train
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 1e-4
min_lr: 8e-5
warmup_lr: 1e-6
weight_decay: 0.05
max_epoch: 4
batch_size_train: 64
batch_size_eval: 64
num_workers: 4
warmup_steps: 5000
iters_per_epoch: 5000
seed: 42
output_dir: "output/bindgpt4"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True