rename some components for further extension

unknown 2023-05-22 18:48:43 +08:00
parent 49f6b84880
commit d2ba3c48b6
19 changed files with 292 additions and 101 deletions

View File

@@ -32,10 +32,10 @@ class Registry:
         """
 
         def wrap(builder_cls):
-            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+            from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
 
             assert issubclass(
-                builder_cls, BaseDatasetBuilder
+                builder_cls, ImageBaseDatasetBuilder
             ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                 builder_cls
            )
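
After this rename, every registered builder must derive from ImageBaseDatasetBuilder (note the assertion message still reads "All builders must inherit BaseDatasetBuilder class" even though the check now tests the renamed class). A minimal sketch of a conforming builder, assuming register_builder keeps the decorator form used elsewhere in this commit; the registry key, class name, and config path below are hypothetical:

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder


@registry.register_builder("my_dataset")  # hypothetical registry key
class MyDatasetBuilderImage(ImageBaseDatasetBuilder):
    # hypothetical default config path
    DATASET_CONFIG_DICT = {"default": "configs/datasets/my_dataset/defaults.yaml"}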

View File

@@ -5,7 +5,7 @@ model:
   freeze_imagebind: True
 
   # Q-Former
-  freeze_qformer: True
+  freeze_qformer: False
   num_query_token: 32
 
   # Vicuna

View File

@@ -5,18 +5,18 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 
-from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
+from minigpt4.datasets.builders.image_base_dataset_builder import load_dataset_config
 from minigpt4.datasets.builders.image_text_pair_builder import (
-    CCSBUBuilder,
-    LaionBuilder,
-    CCSBUAlignBuilder
+    CCSBUBuilderImage,
+    LaionBuilderImage,
+    CCSBUAlignBuilderImage
 )
 from minigpt4.common.registry import registry
 
 __all__ = [
-    "CCSBUBuilder",
-    "LaionBuilder",
-    "CCSBUAlignBuilder"
+    "CCSBUBuilderImage",
+    "LaionBuilderImage",
+    "CCSBUAlignBuilderImage"
 ]

View File

@@ -0,0 +1,99 @@
import logging
import os
import shutil
import warnings

from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url

import minigpt4.common.utils as utils
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from minigpt4.common.registry import registry
from minigpt4.datasets.builders import load_dataset_config
from minigpt4.processors.base_processor import BaseProcessor


class AudioBaseDatasetBuilder:
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        self._download_ann()
        self._download_aud()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the audio-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # if only dirname is provided, suffix with basename of URL.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)
                        download_url(url=url_or_filename, root=dirname, filename=filename)
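
No subclass of AudioBaseDatasetBuilder exists yet in this commit. A hedged sketch of what one could look like; the import paths, registry key, dataset wiring, and config path are all assumptions. Note also that register_builder, as patched above, still asserts ImageBaseDatasetBuilder, so an audio builder would trip that check until the assert is generalized:

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder  # assumed path
from minigpt4.datasets.datasets.generic_audio_dataset import GenericAudioDataset  # assumed path


@registry.register_builder("generic_audio")  # hypothetical key; see assert caveat above
class GenericAudioBuilder(AudioBaseDatasetBuilder):
    train_dataset_cls = GenericAudioDataset

    # hypothetical config path
    DATASET_CONFIG_DICT = {"default": "configs/datasets/audio/defaults.yaml"}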

View File

@@ -21,8 +21,7 @@ from minigpt4.common.registry import registry
 from minigpt4.processors.base_processor import BaseProcessor
 
 
-
-class BaseDatasetBuilder:
+class ImageBaseDatasetBuilder:
     train_dataset_cls, eval_dataset_cls = None, None
 
     def __init__(self, cfg=None):

View File

@@ -3,13 +3,13 @@ import logging
 import warnings
 
 from minigpt4.common.registry import registry
-from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
-from minigpt4.datasets.datasets.laion_dataset import LaionDataset
-from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
+from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
+from minigpt4.datasets.datasets.image_caption.laion_dataset import LaionDataset
+from minigpt4.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDatasetImageImageCaptionDataset
 
 
 @registry.register_builder("cc_sbu")
-class CCSBUBuilder(BaseDatasetBuilder):
+class CCSBUBuilderImage(ImageBaseDatasetBuilder):
     train_dataset_cls = CCSBUDataset
 
     DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
@@ -41,7 +41,7 @@ class CCSBUBuilder(BaseDatasetBuilder):
 
 @registry.register_builder("laion")
-class LaionBuilder(BaseDatasetBuilder):
+class LaionBuilderImage(ImageBaseDatasetBuilder):
     train_dataset_cls = LaionDataset
 
     DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
@@ -73,8 +73,8 @@ class LaionBuilder(BaseDatasetBuilder):
 
 @registry.register_builder("cc_sbu_align")
-class CCSBUAlignBuilder(BaseDatasetBuilder):
-    train_dataset_cls = CCSBUAlignDataset
+class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder):
+    train_dataset_cls = CCSBUAlignDatasetImageImageCaptionDataset
 
     DATASET_CONFIG_DICT = {
         "default": "configs/datasets/cc_sbu/align.yaml",

View File

@@ -5,32 +5,44 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 
-import gzip
 import logging
-import os
-import random as rnd
-import tarfile
-import zipfile
 import random
-from typing import List
+from typing import List, Iterable
 
-from tqdm import tqdm
 
 import decord
-from decord import VideoReader
 import webdataset as wds
-import numpy as np
 import torch
-from torch.utils.data.dataset import IterableDataset
+from torch.utils.data import IterableDataset, Dataset, ConcatDataset
 
 from minigpt4.common.registry import registry
-from minigpt4.datasets.datasets.base_dataset import ConcatDataset
 
 decord.bridge.set_bridge("torch")
 MAX_INT = registry.get("MAX_INT")
 
 
-class ChainDataset(wds.DataPipeline):
+class WrappedConcatDataset(ConcatDataset):
+    def __init__(self, datasets: Iterable[Dataset]) -> None:
+        super().__init__(datasets)
+
+    def collater(self, samples):
+        # TODO For now only supports datasets with same underlying collater implementations
+        all_keys = set()
+        for s in samples:
+            all_keys.update(s)
+
+        shared_keys = all_keys
+        for s in samples:
+            shared_keys = shared_keys & set(s.keys())
+
+        samples_shared_keys = []
+        for s in samples:
+            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+        return self.datasets[0].collater(samples_shared_keys)
+
+
+class WrappedChainDataset(wds.DataPipeline):
     r"""Dataset for chaining multiple :class:`DataPipeline` s.
 
     This class is useful to assemble different existing dataset streams. The
@@ -40,6 +52,7 @@ class ChainDataset(wds.DataPipeline):
     Args:
         datasets (iterable of IterableDataset): datasets to be chained together
     """
+
     def __init__(self, datasets: List[wds.DataPipeline]) -> None:
         super().__init__()
         self.datasets = datasets
@@ -173,7 +186,7 @@ def concat_datasets(datasets):
         # concatenate map-style datasets and iterable-style datasets separately
         if len(iterable_datasets) > 1:
             chained_datasets = (
-                ChainDataset(iterable_datasets)
+                WrappedChainDataset(iterable_datasets)
             )
         elif len(iterable_datasets) == 1:
             chained_datasets = iterable_datasets[0]
@@ -181,7 +194,7 @@ def concat_datasets(datasets):
             chained_datasets = None
 
         concat_datasets = (
-            ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+            WrappedConcatDataset(map_datasets) if len(map_datasets) > 0 else None
         )
 
         train_datasets = concat_datasets, chained_datasets
@@ -193,4 +206,3 @@ def concat_datasets(datasets):
         datasets[split_name] = train_datasets
 
     return datasets
-
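
WrappedConcatDataset.collater keeps only the keys common to every sample before handing the batch to the first dataset's collater. A self-contained sketch of that key-intersection step (standalone, not using the classes above):

samples = [
    {"image": 1, "text_input": "a", "weight": 0.5},
    {"image": 2, "text_input": "b"},  # no "weight" key
]

# intersect the key sets of all samples
shared_keys = set(samples[0])
for s in samples:
    shared_keys &= set(s)

# trim each sample down to the shared keys before collation
trimmed = [{k: s[k] for k in shared_keys} for s in samples]
print(sorted(shared_keys))  # ['image', 'text_input'] -- 'weight' is dropped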

View File

@@ -0,0 +1,26 @@
import json

from torch.utils.data import Dataset, default_collate
import webdataset as wds

from minigpt4.datasets.datasets.base_dataset import BaseDataset


class GenericAudioDataset(BaseDataset):
    def __init__(self, vision_processor, text_processor, location):
        super().__init__(x_processor=vision_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }
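
A hedged smoke test for GenericAudioDataset, assuming shards exist at a placeholder path; the identity lambdas stand in for real processors. Note the pipeline decodes with wds.torch_audio but still selects the "jpg" key and returns the tensor under "image", mirroring the image pipelines above:

dataset = GenericAudioDataset(
    vision_processor=lambda x: x,   # identity stand-in
    text_processor=lambda s: s,     # identity stand-in
    location="data/audio/{00000..00009}.tar",  # hypothetical shard pattern
)
sample = next(iter(dataset.inner_dataset))
print(sample.keys())  # expect "image" and "text_input"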

View File

@@ -8,25 +8,25 @@
 import json
 from typing import Iterable
 
-from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate
 
 
 class BaseDataset(Dataset):
     def __init__(
-        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
+        self, x_processor=None, text_processor=None, x_root=None, ann_paths=[]
     ):
         """
-        vis_root (string): Root directory of images (e.g. coco/images/)
+        x_root (string): Root directory of data in modality X (e.g. coco/images/, etc.)
         ann_root (string): directory to store the annotation file
         """
-        self.vis_root = vis_root
+        self.x_root = x_root
 
         self.annotation = []
         for ann_path in ann_paths:
             self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
 
-        self.vis_processor = vis_processor
+        self.x_processor = x_processor
         self.text_processor = text_processor
 
         self._add_instance_ids()
@@ -37,32 +37,11 @@ class BaseDataset(Dataset):
     def collater(self, samples):
         return default_collate(samples)
 
-    def set_processors(self, vis_processor, text_processor):
-        self.vis_processor = vis_processor
+    def set_processors(self, x_processor, text_processor):
+        self.x_processor = x_processor
         self.text_processor = text_processor
 
     def _add_instance_ids(self, key="instance_id"):
         for idx, ann in enumerate(self.annotation):
             ann[key] = str(idx)
-
-
-class ConcatDataset(ConcatDataset):
-    def __init__(self, datasets: Iterable[Dataset]) -> None:
-        super().__init__(datasets)
-
-    def collater(self, samples):
-        # TODO For now only supports datasets with same underlying collater implementations
-        all_keys = set()
-        for s in samples:
-            all_keys.update(s)
-
-        shared_keys = all_keys
-        for s in samples:
-            shared_keys = shared_keys & set(s.keys())
-
-        samples_shared_keys = []
-        for s in samples:
-            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
-
-        return self.datasets[0].collater(samples_shared_keys)
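
With vis_root/vis_processor generalized to x_root/x_processor, BaseDataset is modality-agnostic. A hedged sketch of a minimal map-style subclass using the renamed fields; the class name and annotation schema here are hypothetical:

from minigpt4.datasets.datasets.base_dataset import BaseDataset


class ToyXDataset(BaseDataset):
    def __getitem__(self, index):
        ann = self.annotation[index]
        # x_root / x_processor replace the old vis_root / vis_processor
        return {
            "x": self.x_processor(ann["file"]),                 # whatever modality X is
            "text_input": self.text_processor(ann["caption"]),
            "instance_id": ann["instance_id"],
        }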

View File

@@ -2,12 +2,12 @@ import os
 from PIL import Image
 import webdataset as wds
 
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
-from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+from minigpt4.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset
 
 
 class CCSBUDataset(BaseDataset):
-    def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+    def __init__(self, vision_processor, text_processor, location):
+        super().__init__(x_processor=vision_processor, text_processor=text_processor)
 
         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
@@ -15,7 +15,7 @@ class CCSBUDataset(BaseDataset):
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode("pilrgb", handler=wds.warn_and_continue),
             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )
@@ -26,7 +26,7 @@ class CCSBUDataset(BaseDataset):
         }
 
 
-class CCSBUAlignDataset(CaptionDataset):
+class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
 
     def __getitem__(self, index):
@@ -34,10 +34,10 @@ class CCSBUAlignDataset(CaptionDataset):
         ann = self.annotation[index]
 
         img_file = '{}.jpg'.format(ann["image_id"])
-        image_path = os.path.join(self.vis_root, img_file)
+        image_path = os.path.join(self.x_root, img_file)
         image = Image.open(image_path).convert("RGB")
 
-        image = self.vis_processor(image)
+        image = self.x_processor(image)
         caption = ann["caption"]
 
         return {

View File

@@ -6,32 +6,20 @@
 """
 import os
-from collections import OrderedDict
 
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
 from PIL import Image
+from minigpt4.datasets.datasets.mixins.mixins import __ImageDisplMixin
 
 
-class __DisplMixin:
-    def displ_item(self, index):
-        sample, ann = self.__getitem__(index), self.annotation[index]
-
-        return OrderedDict(
-            {
-                "file": ann["image"],
-                "caption": ann["caption"],
-                "image": sample["image"],
-            }
-        )
-
-
-class CaptionDataset(BaseDataset, __DisplMixin):
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
+    def __init__(self, vision_processor, text_processor, x_root, ann_paths):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         """
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+        super().__init__(vision_processor, text_processor, x_root, ann_paths)
 
         self.img_ids = {}
         n = 0
@@ -47,10 +35,10 @@ class CaptionDataset(BaseDataset, __DisplMixin):
         ann = self.annotation[index]
 
         img_file = '{:0>12}.jpg'.format(ann["image_id"])
-        image_path = os.path.join(self.vis_root, img_file)
+        image_path = os.path.join(self.x_root, img_file)
         image = Image.open(image_path).convert("RGB")
 
-        image = self.vis_processor(image)
+        image = self.x_processor(image)
         caption = self.text_processor(ann["caption"])
 
         return {
@@ -60,23 +48,23 @@ class CaptionDataset(BaseDataset, __DisplMixin):
         }
 
 
-class CaptionEvalDataset(BaseDataset, __DisplMixin):
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
+    def __init__(self, vision_processor, text_processor, x_root, ann_paths):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         split (string): val or test
         """
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+        super().__init__(vision_processor, text_processor, x_root, ann_paths)
 
     def __getitem__(self, index):
         ann = self.annotation[index]
 
-        image_path = os.path.join(self.vis_root, ann["image"])
+        image_path = os.path.join(self.x_root, ann["image"])
         image = Image.open(image_path).convert("RGB")
 
-        image = self.vis_processor(image)
+        image = self.x_processor(image)
 
         return {
             "image": image,

View File

@@ -10,8 +10,8 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset
 
 class LaionDataset(BaseDataset):
-    def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+    def __init__(self, vision_processor, text_processor, location):
+        super().__init__(x_processor=vision_processor, text_processor=text_processor)
 
         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
@@ -19,7 +19,7 @@ class LaionDataset(BaseDataset):
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode("pilrgb", handler=wds.warn_and_continue),
             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )

View File

@@ -0,0 +1,30 @@
from collections import OrderedDict


class __ImageDisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )


class __AudioDisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]
        # TODO: Finish the Audio Display Mixin
        '''
        return OrderedDict(
            {
            }
        )
        '''
        raise NotImplementedError
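
displ_item gives a quick OrderedDict view of a single example for inspection. A hedged usage sketch, assuming ds is any dataset that mixes in __ImageDisplMixin (e.g. the renamed ImageCaptionDataset); the helper name is hypothetical:

def show_item(ds, index=0):
    # ds is a mixin-equipped dataset built elsewhere (hypothetical instance);
    # displ_item returns an OrderedDict with "file", "caption", and "image".
    item = ds.displ_item(index)
    for key, value in item.items():
        print(key, type(value))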

View File

@@ -146,3 +146,4 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
         std = cfg.get("std", None)
 
         return cls(image_size=image_size, mean=mean, std=std)
+

View File

@@ -24,7 +24,7 @@ from minigpt4.common.dist_utils import (
 )
 from minigpt4.common.registry import registry
 from minigpt4.common.utils import is_url
-from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
+from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, WrappedChainDataset
 from minigpt4.datasets.datasets.dataloader_utils import (
     IterLoader,
     MultiIterLoader,
@@ -219,7 +219,7 @@ class RunnerBase:
                 num_records = sum(
                     [
                         len(d)
-                        if not type(d) in [wds.DataPipeline, ChainDataset]
+                        if not type(d) in [wds.DataPipeline, WrappedChainDataset]
                         else 0
                         for d in self.datasets[split_name]
                     ]
@@ -503,7 +503,7 @@ class RunnerBase:
         def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
             # create a single dataloader for each split
-            if isinstance(dataset, ChainDataset) or isinstance(
+            if isinstance(dataset, WrappedChainDataset) or isinstance(
                 dataset, wds.DataPipeline
             ):
                 # wds.WebdDataset instance are chained together

View File

@@ -0,0 +1,57 @@
model:
  arch: bind_gpt4
  model_type: pretrain_vicuna
  freeze_imagebind: True
  freeze_qformer: False

datasets:
  laion:
    vis_processor:
      train:
        name: "imagebind_vision_train"
        image_size: 224
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 115
  cc_sbu:
    vis_processor:
      train:
        name: "imagebind_vision_train"
        image_size: 224
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 14

run:
  task: imagebind_qformer_train
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 1e-4
  min_lr: 8e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 4
  batch_size_train: 64
  batch_size_eval: 64
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/bindgpt4"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
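
A hedged sketch for loading and inspecting this config with OmegaConf, which the builders above already depend on; the on-disk filename is an assumption:

from omegaconf import OmegaConf

cfg = OmegaConf.load("train_configs/bindgpt4_qformer_train.yaml")  # hypothetical path
print(cfg.run.max_epoch)                # 4
print(cfg.datasets.laion.sample_ratio)  # 115
print(cfg.model.freeze_qformer)         # False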