Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-08 12:00:47 +00:00)

Commit d2ba3c48b6 (parent 49f6b84880): rename some components for further extension
@@ -32,10 +32,10 @@ class Registry:
         """
 
         def wrap(builder_cls):
-            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+            from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
 
             assert issubclass(
-                builder_cls, BaseDatasetBuilder
+                builder_cls, ImageBaseDatasetBuilder
             ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                 builder_cls
             )
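With this change, Registry.register_builder type-checks builders against the renamed ImageBaseDatasetBuilder (note the assertion message still mentions BaseDatasetBuilder). A minimal sketch of what the decorator now expects; the builder key and config path below are hypothetical and not part of this commit:

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder


@registry.register_builder("my_image_dataset")  # hypothetical registry key
class MyImageBuilder(ImageBaseDatasetBuilder):
    # inheriting ImageBaseDatasetBuilder is what satisfies the issubclass assert above
    DATASET_CONFIG_DICT = {"default": "configs/datasets/my_image_dataset/defaults.yaml"}  # hypothetical path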
@@ -5,7 +5,7 @@ model:
   freeze_imagebind: True
 
   # Q-Former
-  freeze_qformer: True
+  freeze_qformer: False
   num_query_token: 32
 
   # Vicuna
@@ -5,18 +5,18 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 
-from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
+from minigpt4.datasets.builders.image_base_dataset_builder import load_dataset_config
 from minigpt4.datasets.builders.image_text_pair_builder import (
-    CCSBUBuilder,
-    LaionBuilder,
-    CCSBUAlignBuilder
+    CCSBUBuilderImage,
+    LaionBuilderImage,
+    CCSBUAlignBuilderImage
 )
 from minigpt4.common.registry import registry
 
 __all__ = [
-    "CCSBUBuilder",
-    "LaionBuilder",
-    "CCSBUAlignBuilder"
+    "CCSBUBuilderImage",
+    "LaionBuilderImage",
+    "CCSBUAlignBuilderImage"
 ]
 
 
minigpt4/datasets/builders/audio_base_dataset_builder.py (new file, 99 lines)
@@ -0,0 +1,99 @@
+import logging
+import os
+import shutil
+import warnings
+
+from omegaconf import OmegaConf
+import torch.distributed as dist
+from torchvision.datasets.utils import download_url
+
+import minigpt4.common.utils as utils
+from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
+from minigpt4.common.registry import registry
+from minigpt4.datasets.builders import load_dataset_config
+from minigpt4.processors.base_processor import BaseProcessor
+
+
+class AudioBaseDatasetBuilder:
+    train_dataset_cls, eval_dataset_cls = None, None
+
+    def __init__(self, cfg=None):
+        super().__init__()
+
+        if cfg is None:
+            # help to create datasets from default config.
+            self.config = load_dataset_config(self.default_config_path())
+        elif isinstance(cfg, str):
+            self.config = load_dataset_config(cfg)
+        else:
+            # when called from task.build_dataset()
+            self.config = cfg
+
+        self.data_type = self.config.data_type
+
+    @classmethod
+    def default_config_path(cls, type="default"):
+        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
+
+    def _download_data(self):
+        self._download_ann()
+        self._download_aud()
+
+    def _download_ann(self):
+        """
+        Download annotation files if necessary.
+        All the audio-language datasets should have annotations of unified format.
+
+        storage_path can be:
+          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
+          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
+
+        Local annotation paths should be relative.
+        """
+        anns = self.config.build_info.annotations
+
+        splits = anns.keys()
+
+        cache_root = registry.get_path("cache_root")
+
+        for split in splits:
+            info = anns[split]
+
+            urls, storage_paths = info.get("url", None), info.storage
+
+            if isinstance(urls, str):
+                urls = [urls]
+            if isinstance(storage_paths, str):
+                storage_paths = [storage_paths]
+
+            assert len(urls) == len(storage_paths)
+
+            for url_or_filename, storage_path in zip(urls, storage_paths):
+                # if storage_path is relative, make it full by prefixing with cache_root.
+                if not os.path.isabs(storage_path):
+                    storage_path = os.path.join(cache_root, storage_path)
+
+                dirname = os.path.dirname(storage_path)
+                if not os.path.exists(dirname):
+                    os.makedirs(dirname)
+
+                if os.path.isfile(url_or_filename):
+                    src, dst = url_or_filename, storage_path
+                    if not os.path.exists(dst):
+                        shutil.copyfile(src=src, dst=dst)
+                    else:
+                        logging.info("Using existing file {}.".format(dst))
+                else:
+                    if os.path.isdir(storage_path):
+                        # if only dirname is provided, suffix with basename of URL.
+                        raise ValueError(
+                            "Expecting storage_path to be a file path, got directory {}".format(
+                                storage_path
+                            )
+                        )
+                    else:
+                        filename = os.path.basename(storage_path)
+
+                        download_url(url=url_or_filename, root=dirname, filename=filename)
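AudioBaseDatasetBuilder mirrors the image builder: it loads a dataset config, records data_type, and downloads annotations; _download_data also calls _download_aud, which is not defined in this file, so concrete subclasses are expected to provide it. Note that Registry.register_builder above still asserts ImageBaseDatasetBuilder, so an audio builder cannot yet be registered through that decorator as of this commit. A rough sketch of a subclass, with a hypothetical config path and a stub download step:

from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder


class MyAudioCaptionBuilder(AudioBaseDatasetBuilder):  # hypothetical builder
    train_dataset_cls = None  # would point at an audio-caption Dataset class

    DATASET_CONFIG_DICT = {"default": "configs/datasets/my_audio/defaults.yaml"}  # hypothetical path

    def _download_aud(self):
        # this sketch assumes the audio shards already exist on disk
        pass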
@@ -21,8 +21,7 @@ from minigpt4.common.registry import registry
 from minigpt4.processors.base_processor import BaseProcessor
 
 
-class BaseDatasetBuilder:
-
+class ImageBaseDatasetBuilder:
     train_dataset_cls, eval_dataset_cls = None, None
 
     def __init__(self, cfg=None):
@@ -3,13 +3,13 @@ import logging
 import warnings
 
 from minigpt4.common.registry import registry
-from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
-from minigpt4.datasets.datasets.laion_dataset import LaionDataset
-from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
+from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
+from minigpt4.datasets.datasets.image_caption.laion_dataset import LaionDataset
+from minigpt4.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDatasetImageImageCaptionDataset
 
 
 @registry.register_builder("cc_sbu")
-class CCSBUBuilder(BaseDatasetBuilder):
+class CCSBUBuilderImage(ImageBaseDatasetBuilder):
     train_dataset_cls = CCSBUDataset
 
     DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
@@ -41,7 +41,7 @@ class CCSBUBuilder(BaseDatasetBuilder):
 
 
 @registry.register_builder("laion")
-class LaionBuilder(BaseDatasetBuilder):
+class LaionBuilderImage(ImageBaseDatasetBuilder):
     train_dataset_cls = LaionDataset
 
     DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
@@ -73,8 +73,8 @@ class LaionBuilder(BaseDatasetBuilder):
 
 
 @registry.register_builder("cc_sbu_align")
-class CCSBUAlignBuilder(BaseDatasetBuilder):
-    train_dataset_cls = CCSBUAlignDataset
+class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder):
+    train_dataset_cls = CCSBUAlignDatasetImageImageCaptionDataset
 
     DATASET_CONFIG_DICT = {
         "default": "configs/datasets/cc_sbu/align.yaml",
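The registry keys ("cc_sbu", "laion", "cc_sbu_align") are unchanged; only the Python class names gain the Image suffix, so existing dataset configs keep resolving to the same builders. A sketch of how a renamed builder would be looked up and built, assuming the builder API itself is untouched by the rename:

from minigpt4.common.registry import registry

builder_cls = registry.get_builder_class("cc_sbu_align")  # resolves to CCSBUAlignBuilderImage
builder = builder_cls()                                   # loads configs/datasets/cc_sbu/align.yaml by default
datasets = builder.build_datasets()                       # assumed unchanged inherited entry point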
@@ -5,32 +5,44 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 
-import gzip
 import logging
-import os
-import random as rnd
-import tarfile
-import zipfile
 import random
-from typing import List
-from tqdm import tqdm
+from typing import List, Iterable
 
 import decord
-from decord import VideoReader
 import webdataset as wds
-import numpy as np
 import torch
-from torch.utils.data.dataset import IterableDataset
+from torch.utils.data import IterableDataset, Dataset, ConcatDataset
 
 from minigpt4.common.registry import registry
-from minigpt4.datasets.datasets.base_dataset import ConcatDataset
 
 
 decord.bridge.set_bridge("torch")
 MAX_INT = registry.get("MAX_INT")
 
 
-class ChainDataset(wds.DataPipeline):
+class WrappedConcatDataset(ConcatDataset):
+    def __init__(self, datasets: Iterable[Dataset]) -> None:
+        super().__init__(datasets)
+
+    def collater(self, samples):
+        # TODO For now only supports datasets with same underlying collater implementations
+
+        all_keys = set()
+        for s in samples:
+            all_keys.update(s)
+
+        shared_keys = all_keys
+        for s in samples:
+            shared_keys = shared_keys & set(s.keys())
+
+        samples_shared_keys = []
+        for s in samples:
+            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+        return self.datasets[0].collater(samples_shared_keys)
+
+
+class WrappedChainDataset(wds.DataPipeline):
     r"""Dataset for chaining multiple :class:`DataPipeline` s.
 
     This class is useful to assemble different existing dataset streams. The
@@ -40,6 +52,7 @@ class ChainDataset(wds.DataPipeline):
     Args:
         datasets (iterable of IterableDataset): datasets to be chained together
     """
+
     def __init__(self, datasets: List[wds.DataPipeline]) -> None:
         super().__init__()
         self.datasets = datasets
@@ -173,7 +186,7 @@ def concat_datasets(datasets):
         # concatenate map-style datasets and iterable-style datasets separately
         if len(iterable_datasets) > 1:
             chained_datasets = (
-                ChainDataset(iterable_datasets)
+                WrappedChainDataset(iterable_datasets)
             )
         elif len(iterable_datasets) == 1:
             chained_datasets = iterable_datasets[0]
@@ -181,7 +194,7 @@ def concat_datasets(datasets):
             chained_datasets = None
 
         concat_datasets = (
-            ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+            WrappedConcatDataset(map_datasets) if len(map_datasets) > 0 else None
         )
 
         train_datasets = concat_datasets, chained_datasets
@@ -193,4 +206,3 @@ def concat_datasets(datasets):
         datasets[split_name] = train_datasets
 
     return datasets
-
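WrappedConcatDataset keeps the behavior of the ConcatDataset subclass that previously lived in base_dataset.py: samples are collated on the intersection of their keys, delegating to the first child dataset's collater. A small illustration of that key-intersection step in isolation, using toy dicts rather than real dataset samples:

samples = [
    {"image": 0, "text_input": "a", "instance_id": "0"},
    {"image": 1, "text_input": "b"},  # no "instance_id" key
]

shared_keys = set(samples[0])
for s in samples:
    shared_keys &= set(s)

print(shared_keys)  # {'image', 'text_input'}: only keys present in every sample reach the collater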
New file:
@@ -0,0 +1,26 @@
+import json
+
+from torch.utils.data import Dataset, default_collate
+import webdataset as wds
+from minigpt4.datasets.datasets.base_dataset import BaseDataset
+
+
+class GenericAudioDataset(BaseDataset):
+    def __init__(self, vision_processor, text_processor, location):
+        super().__init__(x_processor=vision_processor, text_processor=text_processor)
+
+        self.inner_dataset = wds.DataPipeline(
+            wds.ResampledShards(location),
+            wds.tarfile_to_samples(handler=wds.warn_and_continue),
+            wds.shuffle(1000, handler=wds.warn_and_continue),
+            wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
+            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
+            wds.map(self.to_dict, handler=wds.warn_and_continue),
+        )
+
+    def to_dict(self, sample):
+        return {
+            "image": sample[0],
+            "text_input": self.text_processor(sample[1]["caption"]),
+        }
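GenericAudioDataset reuses the webdataset pipeline shape of CCSBUDataset but decodes with wds.torch_audio; note that as of this commit it still selects the "jpg" member of each tar sample and stores the processed result under the "image" key. A hedged usage sketch; the shard pattern and processors below are placeholders, not values from the repository:

from minigpt4.processors.base_processor import BaseProcessor
# GenericAudioDataset comes from the new module added above (its path is not shown in this view)

dataset = GenericAudioDataset(
    vision_processor=BaseProcessor(),  # placeholder; a real audio processor is assumed here
    text_processor=BaseProcessor(),
    location="/path/to/shards/{00000..00010}.tar",  # hypothetical webdataset shard pattern
)

sample = next(iter(dataset.inner_dataset))  # expected form: {"image": <decoded audio>, "text_input": <caption>}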
@@ -8,25 +8,25 @@
 import json
 from typing import Iterable
 
-from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate
 
 
 class BaseDataset(Dataset):
     def __init__(
-        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
+        self, x_processor=None, text_processor=None, x_root=None, ann_paths=[]
     ):
         """
-        vis_root (string): Root directory of images (e.g. coco/images/)
+        x_root (string): Root directory of data in modality X (e.g. coco/images/, etc.)
         ann_root (string): directory to store the annotation file
         """
-        self.vis_root = vis_root
+        self.x_root = x_root
 
         self.annotation = []
         for ann_path in ann_paths:
             self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
 
-        self.vis_processor = vis_processor
+        self.x_processor = x_processor
         self.text_processor = text_processor
 
         self._add_instance_ids()
@@ -37,32 +37,11 @@ class BaseDataset(Dataset):
     def collater(self, samples):
         return default_collate(samples)
 
-    def set_processors(self, vis_processor, text_processor):
-        self.vis_processor = vis_processor
+    def set_processors(self, x_processor, text_processor):
+        self.x_processor = x_processor
         self.text_processor = text_processor
 
     def _add_instance_ids(self, key="instance_id"):
         for idx, ann in enumerate(self.annotation):
             ann[key] = str(idx)
-
-
-class ConcatDataset(ConcatDataset):
-    def __init__(self, datasets: Iterable[Dataset]) -> None:
-        super().__init__(datasets)
-
-    def collater(self, samples):
-        # TODO For now only supports datasets with same underlying collater implementations
-
-        all_keys = set()
-        for s in samples:
-            all_keys.update(s)
-
-        shared_keys = all_keys
-        for s in samples:
-            shared_keys = shared_keys & set(s.keys())
-
-        samples_shared_keys = []
-        for s in samples:
-            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
-
-        return self.datasets[0].collater(samples_shared_keys)
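After the rename, BaseDataset no longer presumes a visual modality: subclasses pass whatever processor and root directory their modality needs via x_processor and x_root. A sketch of a map-style subclass under the new naming; the class and the annotation fields it reads are illustrative assumptions:

from minigpt4.datasets.datasets.base_dataset import BaseDataset


class MyModalityDataset(BaseDataset):  # hypothetical subclass
    def __init__(self, x_processor, text_processor, x_root, ann_paths):
        super().__init__(
            x_processor=x_processor,
            text_processor=text_processor,
            x_root=x_root,
            ann_paths=ann_paths,
        )

    def __getitem__(self, index):
        ann = self.annotation[index]
        # load the raw item for this modality from self.x_root, then run self.x_processor on it
        return {
            "text_input": self.text_processor(ann["caption"]),
            "instance_id": ann["instance_id"],
        }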
@@ -2,12 +2,12 @@ import os
 from PIL import Image
 import webdataset as wds
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
-from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+from minigpt4.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset
 
 
 class CCSBUDataset(BaseDataset):
-    def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+    def __init__(self, vision_processor, text_processor, location):
+        super().__init__(x_processor=vision_processor, text_processor=text_processor)
 
         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
@@ -15,7 +15,7 @@ class CCSBUDataset(BaseDataset):
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode("pilrgb", handler=wds.warn_and_continue),
             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )
 
@@ -26,7 +26,7 @@ class CCSBUDataset(BaseDataset):
         }
 
 
-class CCSBUAlignDataset(CaptionDataset):
+class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
 
     def __getitem__(self, index):
 
@@ -34,10 +34,10 @@ class CCSBUAlignDataset(CaptionDataset):
         ann = self.annotation[index]
 
         img_file = '{}.jpg'.format(ann["image_id"])
-        image_path = os.path.join(self.vis_root, img_file)
+        image_path = os.path.join(self.x_root, img_file)
         image = Image.open(image_path).convert("RGB")
 
-        image = self.vis_processor(image)
+        image = self.x_processor(image)
         caption = ann["caption"]
 
         return {
@@ -6,32 +6,20 @@
 """
 
 import os
-from collections import OrderedDict
 
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
 from PIL import Image
 
-
-class __DisplMixin:
-    def displ_item(self, index):
-        sample, ann = self.__getitem__(index), self.annotation[index]
-
-        return OrderedDict(
-            {
-                "file": ann["image"],
-                "caption": ann["caption"],
-                "image": sample["image"],
-            }
-        )
-
+from minigpt4.datasets.datasets.mixins.mixins import __ImageDisplMixin
 
-class CaptionDataset(BaseDataset, __DisplMixin):
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+
+class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
+    def __init__(self, vision_processor, text_processor, x_root, ann_paths):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         """
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+        super().__init__(vision_processor, text_processor, x_root, ann_paths)
 
         self.img_ids = {}
         n = 0
@@ -47,10 +35,10 @@ class CaptionDataset(BaseDataset, __DisplMixin):
         ann = self.annotation[index]
 
         img_file = '{:0>12}.jpg'.format(ann["image_id"])
-        image_path = os.path.join(self.vis_root, img_file)
+        image_path = os.path.join(self.x_root, img_file)
         image = Image.open(image_path).convert("RGB")
 
-        image = self.vis_processor(image)
+        image = self.x_processor(image)
         caption = self.text_processor(ann["caption"])
 
         return {
@@ -60,23 +48,23 @@ class CaptionDataset(BaseDataset, __DisplMixin):
         }
 
 
-class CaptionEvalDataset(BaseDataset, __DisplMixin):
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
+    def __init__(self, vision_processor, text_processor, x_root, ann_paths):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         split (string): val or test
         """
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+        super().__init__(vision_processor, text_processor, x_root, ann_paths)
 
     def __getitem__(self, index):
 
         ann = self.annotation[index]
 
-        image_path = os.path.join(self.vis_root, ann["image"])
+        image_path = os.path.join(self.x_root, ann["image"])
         image = Image.open(image_path).convert("RGB")
 
-        image = self.vis_processor(image)
+        image = self.x_processor(image)
 
         return {
             "image": image,
@@ -10,8 +10,8 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset
 
 
 class LaionDataset(BaseDataset):
-    def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+    def __init__(self, vision_processor, text_processor, location):
+        super().__init__(x_processor=vision_processor, text_processor=text_processor)
 
         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
@@ -19,7 +19,7 @@ class LaionDataset(BaseDataset):
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode("pilrgb", handler=wds.warn_and_continue),
             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )
 
minigpt4/datasets/datasets/mixins/__init__.py (new empty file)

minigpt4/datasets/datasets/mixins/mixins.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+
+class __ImageDisplMixin:
+    def displ_item(self, index):
+        sample, ann = self.__getitem__(index), self.annotation[index]
+
+        return OrderedDict(
+            {
+                "file": ann["image"],
+                "caption": ann["caption"],
+                "image": sample["image"],
+            }
+        )
+
+
+class __AudioDisplMixin:
+    def displ_item(self, index):
+        sample, ann = self.__getitem__(index), self.annotation[index]
+
+        # TODO: Finish the Audio Display Mixin
+        '''
+        return OrderedDict(
+            {
+            }
+        )
+        '''
+
+        raise NotImplementedError
@@ -146,3 +146,4 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
         std = cfg.get("std", None)
 
         return cls(image_size=image_size, mean=mean, std=std)
+
@@ -24,7 +24,7 @@ from minigpt4.common.dist_utils import (
 )
 from minigpt4.common.registry import registry
 from minigpt4.common.utils import is_url
-from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
+from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, WrappedChainDataset
 from minigpt4.datasets.datasets.dataloader_utils import (
     IterLoader,
     MultiIterLoader,
@@ -219,7 +219,7 @@ class RunnerBase:
             num_records = sum(
                 [
                     len(d)
-                    if not type(d) in [wds.DataPipeline, ChainDataset]
+                    if not type(d) in [wds.DataPipeline, WrappedChainDataset]
                     else 0
                     for d in self.datasets[split_name]
                 ]
@@ -503,7 +503,7 @@ class RunnerBase:
 
     def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
         # create a single dataloader for each split
-        if isinstance(dataset, ChainDataset) or isinstance(
+        if isinstance(dataset, WrappedChainDataset) or isinstance(
             dataset, wds.DataPipeline
         ):
             # wds.WebdDataset instance are chained together
train_configs/bindgpt4.yaml (new file, 57 lines)
@@ -0,0 +1,57 @@
+model:
+  arch: bind_gpt4
+  model_type: pretrain_vicuna
+  freeze_imagebind: True
+  freeze_qformer: False
+
+
+datasets:
+  laion:
+    vis_processor:
+      train:
+        name: "imagebind_vision_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "imagebind_caption"
+    sample_ratio: 115
+  cc_sbu:
+    vis_processor:
+      train:
+        name: "imagebind_vision_train"
+        image_size: 224
+    text_processor:
+      train:
+        name: "imagebind_caption"
+    sample_ratio: 14
+
+
+run:
+  task: imagebind_qformer_train
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 1e-4
+  min_lr: 8e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 4
+  batch_size_train: 64
+  batch_size_eval: 64
+  num_workers: 4
+  warmup_steps: 5000
+  iters_per_epoch: 5000
+
+  seed: 42
+  output_dir: "output/bindgpt4"
+
+  amp: True
+  resume_ckpt_path: null
+
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
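The new training config pairs the ImageBind-based model (arch bind_gpt4 with a trainable Q-Former) with the laion and cc_sbu webdataset builders renamed above. Assuming the existing MiniGPT-4 entry point is reused for this stage and the imagebind_qformer_train task is registered, a launch would look roughly like:

torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/bindgpt4.yaml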