From d2ba3c48b6a89e2eca58f50cb1468e4ae9e0c13b Mon Sep 17 00:00:00 2001 From: unknown <913556700@qq.com> Date: Mon, 22 May 2023 18:48:43 +0800 Subject: [PATCH] rename some components for further extension --- minigpt4/common/registry.py | 4 +- minigpt4/configs/models/bindgpt4.yaml | 2 +- minigpt4/datasets/builders/__init__.py | 14 +-- .../builders/audio_base_dataset_builder.py | 99 +++++++++++++++++++ ...ilder.py => image_base_dataset_builder.py} | 3 +- .../builders/image_text_pair_builder.py | 14 +-- minigpt4/datasets/data_utils.py | 46 +++++---- .../datasets/audio_caption/__init__.py | 0 .../audio_caption/audio_caption_datasets.py | 26 +++++ minigpt4/datasets/datasets/base_dataset.py | 35 ++----- .../datasets/image_caption/__init__.py | 0 .../{ => image_caption}/cc_sbu_dataset.py | 14 +-- .../image_caption_datasets.py} | 34 +++---- .../{ => image_caption}/laion_dataset.py | 6 +- minigpt4/datasets/datasets/mixins/__init__.py | 0 minigpt4/datasets/datasets/mixins/mixins.py | 30 ++++++ minigpt4/processors/imagebind_processor.py | 3 +- minigpt4/runners/runner_base.py | 6 +- train_configs/bindgpt4.yaml | 57 +++++++++++ 19 files changed, 292 insertions(+), 101 deletions(-) create mode 100644 minigpt4/datasets/builders/audio_base_dataset_builder.py rename minigpt4/datasets/builders/{base_dataset_builder.py => image_base_dataset_builder.py} (99%) create mode 100644 minigpt4/datasets/datasets/audio_caption/__init__.py create mode 100644 minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py create mode 100644 minigpt4/datasets/datasets/image_caption/__init__.py rename minigpt4/datasets/datasets/{ => image_caption}/cc_sbu_dataset.py (70%) rename minigpt4/datasets/datasets/{caption_datasets.py => image_caption/image_caption_datasets.py} (62%) rename minigpt4/datasets/datasets/{ => image_caption}/laion_dataset.py (80%) create mode 100644 minigpt4/datasets/datasets/mixins/__init__.py create mode 100644 minigpt4/datasets/datasets/mixins/mixins.py create mode 100644 train_configs/bindgpt4.yaml diff --git a/minigpt4/common/registry.py b/minigpt4/common/registry.py index 679467a..75a6c23 100644 --- a/minigpt4/common/registry.py +++ b/minigpt4/common/registry.py @@ -32,10 +32,10 @@ class Registry: """ def wrap(builder_cls): - from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder + from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder assert issubclass( - builder_cls, BaseDatasetBuilder + builder_cls, ImageBaseDatasetBuilder ), "All builders must inherit BaseDatasetBuilder class, found {}".format( builder_cls ) diff --git a/minigpt4/configs/models/bindgpt4.yaml b/minigpt4/configs/models/bindgpt4.yaml index 3d436ec..3f4b259 100644 --- a/minigpt4/configs/models/bindgpt4.yaml +++ b/minigpt4/configs/models/bindgpt4.yaml @@ -5,7 +5,7 @@ model: freeze_imagebind: True # Q-Former - freeze_qformer: True + freeze_qformer: False num_query_token: 32 # Vicuna diff --git a/minigpt4/datasets/builders/__init__.py b/minigpt4/datasets/builders/__init__.py index 6d09640..744fe2f 100644 --- a/minigpt4/datasets/builders/__init__.py +++ b/minigpt4/datasets/builders/__init__.py @@ -5,18 +5,18 @@ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ -from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config +from minigpt4.datasets.builders.image_base_dataset_builder import load_dataset_config from minigpt4.datasets.builders.image_text_pair_builder import ( - CCSBUBuilder, - LaionBuilder, - CCSBUAlignBuilder + CCSBUBuilderImage, + LaionBuilderImage, + CCSBUAlignBuilderImage ) from minigpt4.common.registry import registry __all__ = [ - "CCSBUBuilder", - "LaionBuilder", - "CCSBUAlignBuilder" + "CCSBUBuilderImage", + "LaionBuilderImage", + "CCSBUAlignBuilderImage" ] diff --git a/minigpt4/datasets/builders/audio_base_dataset_builder.py b/minigpt4/datasets/builders/audio_base_dataset_builder.py new file mode 100644 index 0000000..4e1361e --- /dev/null +++ b/minigpt4/datasets/builders/audio_base_dataset_builder.py @@ -0,0 +1,99 @@ +import logging +import os +import shutil +import warnings + +from omegaconf import OmegaConf +import torch.distributed as dist +from torchvision.datasets.utils import download_url + +import minigpt4.common.utils as utils +from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from minigpt4.common.registry import registry +from minigpt4.datasets.builders import load_dataset_config +from minigpt4.processors.base_processor import BaseProcessor + + +class AudioBaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. + self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type + + @classmethod + def default_config_path(cls, type="default"): + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + self._download_ann() + self._download_aud() + + def _download_ann(self): + """ + Download annotation files if necessary. + All the audio-language datasets should have annotations of unified format. + + storage_path can be: + (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. + (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. + + Local annotation paths should be relative. + """ + anns = self.config.build_info.annotations + + splits = anns.keys() + + cache_root = registry.get_path("cache_root") + + for split in splits: + info = anns[split] + + urls, storage_paths = info.get("url", None), info.storage + + if isinstance(urls, str): + urls = [urls] + if isinstance(storage_paths, str): + storage_paths = [storage_paths] + + assert len(urls) == len(storage_paths) + + for url_or_filename, storage_path in zip(urls, storage_paths): + # if storage_path is relative, make it full by prefixing with cache_root. + if not os.path.isabs(storage_path): + storage_path = os.path.join(cache_root, storage_path) + + dirname = os.path.dirname(storage_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + if os.path.isfile(url_or_filename): + src, dst = url_or_filename, storage_path + if not os.path.exists(dst): + shutil.copyfile(src=src, dst=dst) + else: + logging.info("Using existing file {}.".format(dst)) + else: + if os.path.isdir(storage_path): + # if only dirname is provided, suffix with basename of URL. + raise ValueError( + "Expecting storage_path to be a file path, got directory {}".format( + storage_path + ) + ) + else: + filename = os.path.basename(storage_path) + + download_url(url=url_or_filename, root=dirname, filename=filename) + + diff --git a/minigpt4/datasets/builders/base_dataset_builder.py b/minigpt4/datasets/builders/image_base_dataset_builder.py similarity index 99% rename from minigpt4/datasets/builders/base_dataset_builder.py rename to minigpt4/datasets/builders/image_base_dataset_builder.py index 4b607e3..b9c5bd4 100644 --- a/minigpt4/datasets/builders/base_dataset_builder.py +++ b/minigpt4/datasets/builders/image_base_dataset_builder.py @@ -21,8 +21,7 @@ from minigpt4.common.registry import registry from minigpt4.processors.base_processor import BaseProcessor - -class BaseDatasetBuilder: +class ImageBaseDatasetBuilder: train_dataset_cls, eval_dataset_cls = None, None def __init__(self, cfg=None): diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py index e5d66b8..e8d81f7 100644 --- a/minigpt4/datasets/builders/image_text_pair_builder.py +++ b/minigpt4/datasets/builders/image_text_pair_builder.py @@ -3,13 +3,13 @@ import logging import warnings from minigpt4.common.registry import registry -from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder -from minigpt4.datasets.datasets.laion_dataset import LaionDataset -from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset +from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder +from minigpt4.datasets.datasets.image_caption.laion_dataset import LaionDataset +from minigpt4.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDatasetImageImageCaptionDataset @registry.register_builder("cc_sbu") -class CCSBUBuilder(BaseDatasetBuilder): +class CCSBUBuilderImage(ImageBaseDatasetBuilder): train_dataset_cls = CCSBUDataset DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"} @@ -41,7 +41,7 @@ class CCSBUBuilder(BaseDatasetBuilder): @registry.register_builder("laion") -class LaionBuilder(BaseDatasetBuilder): +class LaionBuilderImage(ImageBaseDatasetBuilder): train_dataset_cls = LaionDataset DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"} @@ -73,8 +73,8 @@ class LaionBuilder(BaseDatasetBuilder): @registry.register_builder("cc_sbu_align") -class CCSBUAlignBuilder(BaseDatasetBuilder): - train_dataset_cls = CCSBUAlignDataset +class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder): + train_dataset_cls = CCSBUAlignDatasetImageImageCaptionDataset DATASET_CONFIG_DICT = { "default": "configs/datasets/cc_sbu/align.yaml", diff --git a/minigpt4/datasets/data_utils.py b/minigpt4/datasets/data_utils.py index cf6497f..a87bc36 100644 --- a/minigpt4/datasets/data_utils.py +++ b/minigpt4/datasets/data_utils.py @@ -5,32 +5,44 @@ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ -import gzip import logging -import os -import random as rnd -import tarfile -import zipfile import random -from typing import List -from tqdm import tqdm +from typing import List, Iterable import decord -from decord import VideoReader import webdataset as wds -import numpy as np import torch -from torch.utils.data.dataset import IterableDataset +from torch.utils.data import IterableDataset, Dataset, ConcatDataset from minigpt4.common.registry import registry -from minigpt4.datasets.datasets.base_dataset import ConcatDataset - decord.bridge.set_bridge("torch") MAX_INT = registry.get("MAX_INT") -class ChainDataset(wds.DataPipeline): +class WrappedConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__(datasets) + + def collater(self, samples): + # TODO For now only supports datasets with same underlying collater implementations + + all_keys = set() + for s in samples: + all_keys.update(s) + + shared_keys = all_keys + for s in samples: + shared_keys = shared_keys & set(s.keys()) + + samples_shared_keys = [] + for s in samples: + samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) + + return self.datasets[0].collater(samples_shared_keys) + + +class WrappedChainDataset(wds.DataPipeline): r"""Dataset for chaining multiple :class:`DataPipeline` s. This class is useful to assemble different existing dataset streams. The @@ -40,6 +52,7 @@ class ChainDataset(wds.DataPipeline): Args: datasets (iterable of IterableDataset): datasets to be chained together """ + def __init__(self, datasets: List[wds.DataPipeline]) -> None: super().__init__() self.datasets = datasets @@ -149,7 +162,7 @@ def concat_datasets(datasets): for split_name in datasets: if split_name != "train": assert ( - len(datasets[split_name]) == 1 + len(datasets[split_name]) == 1 ), "Do not support multiple {} datasets.".format(split_name) datasets[split_name] = datasets[split_name][0] else: @@ -173,7 +186,7 @@ def concat_datasets(datasets): # concatenate map-style datasets and iterable-style datasets separately if len(iterable_datasets) > 1: chained_datasets = ( - ChainDataset(iterable_datasets) + WrappedChainDataset(iterable_datasets) ) elif len(iterable_datasets) == 1: chained_datasets = iterable_datasets[0] @@ -181,7 +194,7 @@ def concat_datasets(datasets): chained_datasets = None concat_datasets = ( - ConcatDataset(map_datasets) if len(map_datasets) > 0 else None + WrappedConcatDataset(map_datasets) if len(map_datasets) > 0 else None ) train_datasets = concat_datasets, chained_datasets @@ -193,4 +206,3 @@ def concat_datasets(datasets): datasets[split_name] = train_datasets return datasets - diff --git a/minigpt4/datasets/datasets/audio_caption/__init__.py b/minigpt4/datasets/datasets/audio_caption/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py b/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py new file mode 100644 index 0000000..52b6dcd --- /dev/null +++ b/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py @@ -0,0 +1,26 @@ +import json + +from torch.utils.data import Dataset, default_collate +import webdataset as wds +from minigpt4.datasets.datasets.base_dataset import BaseDataset + + +class GenericAudioDataset(BaseDataset): + def __init__(self, vision_processor, text_processor, location): + super().__init__(x_processor=vision_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode(wds.torch_audio, handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "image": sample[0], + "text_input": self.text_processor(sample[1]["caption"]), + } diff --git a/minigpt4/datasets/datasets/base_dataset.py b/minigpt4/datasets/datasets/base_dataset.py index ae2a8d0..fc0d648 100644 --- a/minigpt4/datasets/datasets/base_dataset.py +++ b/minigpt4/datasets/datasets/base_dataset.py @@ -8,25 +8,25 @@ import json from typing import Iterable -from torch.utils.data import Dataset, ConcatDataset +from torch.utils.data import Dataset from torch.utils.data.dataloader import default_collate class BaseDataset(Dataset): def __init__( - self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] + self, x_processor=None, text_processor=None, x_root=None, ann_paths=[] ): """ - vis_root (string): Root directory of images (e.g. coco/images/) + x_root (string): Root directory of data in modality X (e.g. coco/images/, etc.) ann_root (string): directory to store the annotation file """ - self.vis_root = vis_root + self.x_root = x_root self.annotation = [] for ann_path in ann_paths: self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) - self.vis_processor = vis_processor + self.x_processor = x_processor self.text_processor = text_processor self._add_instance_ids() @@ -37,32 +37,11 @@ class BaseDataset(Dataset): def collater(self, samples): return default_collate(samples) - def set_processors(self, vis_processor, text_processor): - self.vis_processor = vis_processor + def set_processors(self, x_processor, text_processor): + self.x_processor = x_processor self.text_processor = text_processor def _add_instance_ids(self, key="instance_id"): for idx, ann in enumerate(self.annotation): ann[key] = str(idx) - -class ConcatDataset(ConcatDataset): - def __init__(self, datasets: Iterable[Dataset]) -> None: - super().__init__(datasets) - - def collater(self, samples): - # TODO For now only supports datasets with same underlying collater implementations - - all_keys = set() - for s in samples: - all_keys.update(s) - - shared_keys = all_keys - for s in samples: - shared_keys = shared_keys & set(s.keys()) - - samples_shared_keys = [] - for s in samples: - samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) - - return self.datasets[0].collater(samples_shared_keys) diff --git a/minigpt4/datasets/datasets/image_caption/__init__.py b/minigpt4/datasets/datasets/image_caption/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minigpt4/datasets/datasets/cc_sbu_dataset.py b/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py similarity index 70% rename from minigpt4/datasets/datasets/cc_sbu_dataset.py rename to minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py index f42bbce..82971f6 100644 --- a/minigpt4/datasets/datasets/cc_sbu_dataset.py +++ b/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py @@ -2,12 +2,12 @@ import os from PIL import Image import webdataset as wds from minigpt4.datasets.datasets.base_dataset import BaseDataset -from minigpt4.datasets.datasets.caption_datasets import CaptionDataset +from minigpt4.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset class CCSBUDataset(BaseDataset): - def __init__(self, vis_processor, text_processor, location): - super().__init__(vis_processor=vis_processor, text_processor=text_processor) + def __init__(self, vision_processor, text_processor, location): + super().__init__(x_processor=vision_processor, text_processor=text_processor) self.inner_dataset = wds.DataPipeline( wds.ResampledShards(location), @@ -15,7 +15,7 @@ class CCSBUDataset(BaseDataset): wds.shuffle(1000, handler=wds.warn_and_continue), wds.decode("pilrgb", handler=wds.warn_and_continue), wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), - wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), wds.map(self.to_dict, handler=wds.warn_and_continue), ) @@ -26,7 +26,7 @@ class CCSBUDataset(BaseDataset): } -class CCSBUAlignDataset(CaptionDataset): +class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset): def __getitem__(self, index): @@ -34,10 +34,10 @@ class CCSBUAlignDataset(CaptionDataset): ann = self.annotation[index] img_file = '{}.jpg'.format(ann["image_id"]) - image_path = os.path.join(self.vis_root, img_file) + image_path = os.path.join(self.x_root, img_file) image = Image.open(image_path).convert("RGB") - image = self.vis_processor(image) + image = self.x_processor(image) caption = ann["caption"] return { diff --git a/minigpt4/datasets/datasets/caption_datasets.py b/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py similarity index 62% rename from minigpt4/datasets/datasets/caption_datasets.py rename to minigpt4/datasets/datasets/image_caption/image_caption_datasets.py index 78bab66..120fcf3 100644 --- a/minigpt4/datasets/datasets/caption_datasets.py +++ b/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py @@ -6,32 +6,20 @@ """ import os -from collections import OrderedDict from minigpt4.datasets.datasets.base_dataset import BaseDataset from PIL import Image - -class __DisplMixin: - def displ_item(self, index): - sample, ann = self.__getitem__(index), self.annotation[index] - - return OrderedDict( - { - "file": ann["image"], - "caption": ann["caption"], - "image": sample["image"], - } - ) +from minigpt4.datasets.datasets.mixins.mixins import __ImageDisplMixin -class CaptionDataset(BaseDataset, __DisplMixin): - def __init__(self, vis_processor, text_processor, vis_root, ann_paths): +class ImageCaptionDataset(BaseDataset, __ImageDisplMixin): + def __init__(self, vision_processor, text_processor, x_root, ann_paths): """ vis_root (string): Root directory of images (e.g. coco/images/) ann_root (string): directory to store the annotation file """ - super().__init__(vis_processor, text_processor, vis_root, ann_paths) + super().__init__(vision_processor, text_processor, x_root, ann_paths) self.img_ids = {} n = 0 @@ -47,10 +35,10 @@ class CaptionDataset(BaseDataset, __DisplMixin): ann = self.annotation[index] img_file = '{:0>12}.jpg'.format(ann["image_id"]) - image_path = os.path.join(self.vis_root, img_file) + image_path = os.path.join(self.x_root, img_file) image = Image.open(image_path).convert("RGB") - image = self.vis_processor(image) + image = self.x_processor(image) caption = self.text_processor(ann["caption"]) return { @@ -60,23 +48,23 @@ class CaptionDataset(BaseDataset, __DisplMixin): } -class CaptionEvalDataset(BaseDataset, __DisplMixin): - def __init__(self, vis_processor, text_processor, vis_root, ann_paths): +class CaptionEvalDataset(BaseDataset, __ImageDisplMixin): + def __init__(self, vision_processor, text_processor, x_root, ann_paths): """ vis_root (string): Root directory of images (e.g. coco/images/) ann_root (string): directory to store the annotation file split (string): val or test """ - super().__init__(vis_processor, text_processor, vis_root, ann_paths) + super().__init__(vision_processor, text_processor, x_root, ann_paths) def __getitem__(self, index): ann = self.annotation[index] - image_path = os.path.join(self.vis_root, ann["image"]) + image_path = os.path.join(self.x_root, ann["image"]) image = Image.open(image_path).convert("RGB") - image = self.vis_processor(image) + image = self.x_processor(image) return { "image": image, diff --git a/minigpt4/datasets/datasets/laion_dataset.py b/minigpt4/datasets/datasets/image_caption/laion_dataset.py similarity index 80% rename from minigpt4/datasets/datasets/laion_dataset.py rename to minigpt4/datasets/datasets/image_caption/laion_dataset.py index 1becbe4..8cb34a8 100644 --- a/minigpt4/datasets/datasets/laion_dataset.py +++ b/minigpt4/datasets/datasets/image_caption/laion_dataset.py @@ -10,8 +10,8 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset class LaionDataset(BaseDataset): - def __init__(self, vis_processor, text_processor, location): - super().__init__(vis_processor=vis_processor, text_processor=text_processor) + def __init__(self, vision_processor, text_processor, location): + super().__init__(x_processor=vision_processor, text_processor=text_processor) self.inner_dataset = wds.DataPipeline( wds.ResampledShards(location), @@ -19,7 +19,7 @@ class LaionDataset(BaseDataset): wds.shuffle(1000, handler=wds.warn_and_continue), wds.decode("pilrgb", handler=wds.warn_and_continue), wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), - wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), wds.map(self.to_dict, handler=wds.warn_and_continue), ) diff --git a/minigpt4/datasets/datasets/mixins/__init__.py b/minigpt4/datasets/datasets/mixins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/minigpt4/datasets/datasets/mixins/mixins.py b/minigpt4/datasets/datasets/mixins/mixins.py new file mode 100644 index 0000000..e143147 --- /dev/null +++ b/minigpt4/datasets/datasets/mixins/mixins.py @@ -0,0 +1,30 @@ +from collections import OrderedDict + + +class __ImageDisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "image": sample["image"], + } + ) + + +class __AudioDisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + # TODO: Finish the Audio Display Mixin + ''' + return OrderedDict( + { + } + ) + ''' + + raise NotImplementedError + diff --git a/minigpt4/processors/imagebind_processor.py b/minigpt4/processors/imagebind_processor.py index 4e13560..837f499 100644 --- a/minigpt4/processors/imagebind_processor.py +++ b/minigpt4/processors/imagebind_processor.py @@ -145,4 +145,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor): mean = cfg.get("mean", None) std = cfg.get("std", None) - return cls(image_size=image_size, mean=mean, std=std) \ No newline at end of file + return cls(image_size=image_size, mean=mean, std=std) + diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py index ccb5706..cb38b1f 100644 --- a/minigpt4/runners/runner_base.py +++ b/minigpt4/runners/runner_base.py @@ -24,7 +24,7 @@ from minigpt4.common.dist_utils import ( ) from minigpt4.common.registry import registry from minigpt4.common.utils import is_url -from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset +from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, WrappedChainDataset from minigpt4.datasets.datasets.dataloader_utils import ( IterLoader, MultiIterLoader, @@ -219,7 +219,7 @@ class RunnerBase: num_records = sum( [ len(d) - if not type(d) in [wds.DataPipeline, ChainDataset] + if not type(d) in [wds.DataPipeline, WrappedChainDataset] else 0 for d in self.datasets[split_name] ] @@ -503,7 +503,7 @@ class RunnerBase: def _create_loader(dataset, num_workers, bsz, is_train, collate_fn): # create a single dataloader for each split - if isinstance(dataset, ChainDataset) or isinstance( + if isinstance(dataset, WrappedChainDataset) or isinstance( dataset, wds.DataPipeline ): # wds.WebdDataset instance are chained together diff --git a/train_configs/bindgpt4.yaml b/train_configs/bindgpt4.yaml new file mode 100644 index 0000000..1c96669 --- /dev/null +++ b/train_configs/bindgpt4.yaml @@ -0,0 +1,57 @@ +model: + arch: bind_gpt4 + model_type: pretrain_vicuna + freeze_imagebind: True + freeze_qformer: False + + +datasets: + laion: + vis_processor: + train: + name: "imagebind_vision_train" + image_size: 224 + text_processor: + train: + name: "imagebind_caption" + sample_ratio: 115 + cc_sbu: + vis_processor: + train: + name: "imagebind_vision_train" + image_size: 224 + text_processor: + train: + name: "imagebind_caption" + sample_ratio: 14 + + +run: + task: imagebind_qformer_train + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 4 + batch_size_train: 64 + batch_size_eval: 64 + num_workers: 4 + warmup_steps: 5000 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "output/bindgpt4" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file