Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git, synced 2025-04-05 02:20:47 +00:00

commit d2ba3c48b6 (parent 49f6b84880)

    rename some components for further extension
minigpt4/common/registry.py
@@ -32,10 +32,10 @@ class Registry:
         """

         def wrap(builder_cls):
-            from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+            from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder

             assert issubclass(
-                builder_cls, BaseDatasetBuilder
+                builder_cls, ImageBaseDatasetBuilder
             ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                 builder_cls
             )
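For orientation, builders opt in through this decorator, and the assert above is what enforces the new base class. A minimal sketch of the registration pattern (the dataset key, builder class, and config path here are hypothetical; the pattern follows the real builders later in this diff):

    from minigpt4.common.registry import registry
    from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder

    @registry.register_builder("my_dataset")  # hypothetical dataset key
    class MyDatasetBuilderImage(ImageBaseDatasetBuilder):
        # points the builder at its default dataset config (path is illustrative)
        DATASET_CONFIG_DICT = {"default": "configs/datasets/my_dataset/defaults.yaml"}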
(model config YAML; file path not shown in this view)
@@ -5,7 +5,7 @@ model:
   freeze_imagebind: True

   # Q-Former
-  freeze_qformer: True
+  freeze_qformer: False
   num_query_token: 32

   # Vicuna
minigpt4/datasets/builders/__init__.py
@@ -5,18 +5,18 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """

-from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
+from minigpt4.datasets.builders.image_base_dataset_builder import load_dataset_config
 from minigpt4.datasets.builders.image_text_pair_builder import (
-    CCSBUBuilder,
-    LaionBuilder,
-    CCSBUAlignBuilder
+    CCSBUBuilderImage,
+    LaionBuilderImage,
+    CCSBUAlignBuilderImage
 )
 from minigpt4.common.registry import registry

 __all__ = [
-    "CCSBUBuilder",
-    "LaionBuilder",
-    "CCSBUAlignBuilder"
+    "CCSBUBuilderImage",
+    "LaionBuilderImage",
+    "CCSBUAlignBuilderImage"
 ]
minigpt4/datasets/builders/audio_base_dataset_builder.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import logging
import os
import shutil
import warnings

from omegaconf import OmegaConf
import torch.distributed as dist
from torchvision.datasets.utils import download_url

import minigpt4.common.utils as utils
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from minigpt4.common.registry import registry
from minigpt4.datasets.builders import load_dataset_config
from minigpt4.processors.base_processor import BaseProcessor


class AudioBaseDatasetBuilder:
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        self._download_ann()
        self._download_aud()

    def _download_ann(self):
        """
        Download annotation files if necessary.
        All the audio-language datasets should have annotations of unified format.

        storage_path can be:
          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.

        Local annotation paths should be relative.
        """
        anns = self.config.build_info.annotations

        splits = anns.keys()

        cache_root = registry.get_path("cache_root")

        for split in splits:
            info = anns[split]

            urls, storage_paths = info.get("url", None), info.storage

            if isinstance(urls, str):
                urls = [urls]
            if isinstance(storage_paths, str):
                storage_paths = [storage_paths]

            assert len(urls) == len(storage_paths)

            for url_or_filename, storage_path in zip(urls, storage_paths):
                # if storage_path is relative, make it full by prefixing with cache_root.
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                else:
                    if os.path.isdir(storage_path):
                        # if only dirname is provided, suffix with basename of URL.
                        raise ValueError(
                            "Expecting storage_path to be a file path, got directory {}".format(
                                storage_path
                            )
                        )
                    else:
                        filename = os.path.basename(storage_path)

                        download_url(url=url_or_filename, root=dirname, filename=filename)
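A sketch of how this base class is presumably meant to be extended (the builder name, dataset key, and config path below are hypothetical; the pattern mirrors the image builders elsewhere in this commit):

    from minigpt4.common.registry import registry
    from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder

    @registry.register_builder("audio_caps")  # hypothetical dataset key
    class AudioCapsBuilder(AudioBaseDatasetBuilder):
        train_dataset_cls = None  # e.g. the GenericAudioDataset added later in this commit
        DATASET_CONFIG_DICT = {"default": "configs/datasets/audio_caps/defaults.yaml"}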
minigpt4/datasets/builders/image_base_dataset_builder.py (renamed from base_dataset_builder.py, per the new import paths)
@@ -21,8 +21,7 @@ from minigpt4.common.registry import registry
 from minigpt4.processors.base_processor import BaseProcessor


-
-class BaseDatasetBuilder:
+class ImageBaseDatasetBuilder:
     train_dataset_cls, eval_dataset_cls = None, None

     def __init__(self, cfg=None):
minigpt4/datasets/builders/image_text_pair_builder.py
@@ -3,13 +3,13 @@ import logging
 import warnings

 from minigpt4.common.registry import registry
-from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
-from minigpt4.datasets.datasets.laion_dataset import LaionDataset
-from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
+from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
+from minigpt4.datasets.datasets.image_caption.laion_dataset import LaionDataset
+from minigpt4.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDatasetImageImageCaptionDataset


 @registry.register_builder("cc_sbu")
-class CCSBUBuilder(BaseDatasetBuilder):
+class CCSBUBuilderImage(ImageBaseDatasetBuilder):
     train_dataset_cls = CCSBUDataset

     DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
@@ -41,7 +41,7 @@ class CCSBUBuilder(BaseDatasetBuilder):


 @registry.register_builder("laion")
-class LaionBuilder(BaseDatasetBuilder):
+class LaionBuilderImage(ImageBaseDatasetBuilder):
     train_dataset_cls = LaionDataset

     DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
@@ -73,8 +73,8 @@ class LaionBuilder(BaseDatasetBuilder):


 @registry.register_builder("cc_sbu_align")
-class CCSBUAlignBuilder(BaseDatasetBuilder):
-    train_dataset_cls = CCSBUAlignDataset
+class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder):
+    train_dataset_cls = CCSBUAlignDatasetImageImageCaptionDataset

     DATASET_CONFIG_DICT = {
         "default": "configs/datasets/cc_sbu/align.yaml",
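For context, a builder is typically driven like this (build_datasets comes from the LAVIS-style builder API this code derives from; treat the exact call as an assumption):

    builder = CCSBUBuilderImage()        # loads configs/datasets/cc_sbu/defaults.yaml
    datasets = builder.build_datasets()  # e.g. {"train": <CCSBUDataset instance>}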
minigpt4/datasets/data_utils.py
@@ -5,32 +5,44 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """

 import gzip
 import logging
 import os
 import random as rnd
 import tarfile
 import zipfile
 import random
-from typing import List
+from tqdm import tqdm
+from typing import List, Iterable

 import decord
 from decord import VideoReader
 import webdataset as wds
 import numpy as np
 import torch
-from torch.utils.data.dataset import IterableDataset
+from torch.utils.data import IterableDataset, Dataset, ConcatDataset

 from minigpt4.common.registry import registry
-from minigpt4.datasets.datasets.base_dataset import ConcatDataset


 decord.bridge.set_bridge("torch")
 MAX_INT = registry.get("MAX_INT")


-class ChainDataset(wds.DataPipeline):
+class WrappedConcatDataset(ConcatDataset):
+    def __init__(self, datasets: Iterable[Dataset]) -> None:
+        super().__init__(datasets)
+
+    def collater(self, samples):
+        # TODO For now only supports datasets with same underlying collater implementations
+
+        all_keys = set()
+        for s in samples:
+            all_keys.update(s)
+
+        shared_keys = all_keys
+        for s in samples:
+            shared_keys = shared_keys & set(s.keys())
+
+        samples_shared_keys = []
+        for s in samples:
+            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+        return self.datasets[0].collater(samples_shared_keys)
+
+
+class WrappedChainDataset(wds.DataPipeline):
     r"""Dataset for chaining multiple :class:`DataPipeline` s.

     This class is useful to assemble different existing dataset streams. The
@@ -40,6 +52,7 @@ class ChainDataset(wds.DataPipeline):
     Args:
         datasets (iterable of IterableDataset): datasets to be chained together
     """
+
     def __init__(self, datasets: List[wds.DataPipeline]) -> None:
         super().__init__()
         self.datasets = datasets
@@ -149,7 +162,7 @@ def concat_datasets(datasets):
     for split_name in datasets:
         if split_name != "train":
             assert (
-                len(datasets[split_name]) == 1
+                    len(datasets[split_name]) == 1
             ), "Do not support multiple {} datasets.".format(split_name)
             datasets[split_name] = datasets[split_name][0]
         else:
@@ -173,7 +186,7 @@ def concat_datasets(datasets):
         # concatenate map-style datasets and iterable-style datasets separately
         if len(iterable_datasets) > 1:
             chained_datasets = (
-                ChainDataset(iterable_datasets)
+                WrappedChainDataset(iterable_datasets)
             )
         elif len(iterable_datasets) == 1:
             chained_datasets = iterable_datasets[0]
@@ -181,7 +194,7 @@ def concat_datasets(datasets):
             chained_datasets = None

         concat_datasets = (
-            ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+            WrappedConcatDataset(map_datasets) if len(map_datasets) > 0 else None
         )

         train_datasets = concat_datasets, chained_datasets
@@ -193,4 +206,3 @@ def concat_datasets(datasets):
         datasets[split_name] = train_datasets

     return datasets
-
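To make the shared-key collation concrete, here is a minimal sketch (the two datasets and their sample dicts are hypothetical):

    # ds_a[i] -> {"image": ..., "text_input": ..., "instance_id": ...}
    # ds_b[j] -> {"image": ..., "text_input": ...}
    combined = WrappedConcatDataset([ds_a, ds_b])
    batch = [ds_a[0], ds_b[0]]
    # Only the key intersection {"image", "text_input"} survives; collation is
    # then delegated to ds_a's own collater.
    collated = combined.collater(batch)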
(new file, 26 lines; file path not shown in this view)
@@ -0,0 +1,26 @@
import json

from torch.utils.data import Dataset, default_collate
import webdataset as wds
from minigpt4.datasets.datasets.base_dataset import BaseDataset


class GenericAudioDataset(BaseDataset):
    def __init__(self, vision_processor, text_processor, location):
        super().__init__(x_processor=vision_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }
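A usage sketch (the shard location and processors are placeholders; note that the pipeline still unpacks the "jpg" key and returns the sample under "image", even though the decoder is wds.torch_audio):

    ds = GenericAudioDataset(
        vision_processor=my_audio_processor,  # despite the name, this processes the decoded audio
        text_processor=my_text_processor,
        location="path/to/audio-shards-{00000..00099}.tar",  # placeholder shard pattern
    )
    sample = next(iter(ds.inner_dataset))
    # -> {"image": <processed audio>, "text_input": <processed caption>}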
minigpt4/datasets/datasets/base_dataset.py
@@ -8,25 +8,25 @@
 import json
 from typing import Iterable

-from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data import Dataset
 from torch.utils.data.dataloader import default_collate


 class BaseDataset(Dataset):
     def __init__(
-        self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
+        self, x_processor=None, text_processor=None, x_root=None, ann_paths=[]
     ):
         """
-        vis_root (string): Root directory of images (e.g. coco/images/)
+        x_root (string): Root directory of data in modality X (e.g. coco/images/, etc.)
         ann_root (string): directory to store the annotation file
         """
-        self.vis_root = vis_root
+        self.x_root = x_root

         self.annotation = []
         for ann_path in ann_paths:
             self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])

-        self.vis_processor = vis_processor
+        self.x_processor = x_processor
         self.text_processor = text_processor

         self._add_instance_ids()
@@ -37,32 +37,11 @@ class BaseDataset(Dataset):
     def collater(self, samples):
         return default_collate(samples)

-    def set_processors(self, vis_processor, text_processor):
-        self.vis_processor = vis_processor
+    def set_processors(self, x_processor, text_processor):
+        self.x_processor = x_processor
         self.text_processor = text_processor

     def _add_instance_ids(self, key="instance_id"):
         for idx, ann in enumerate(self.annotation):
             ann[key] = str(idx)
-
-
-class ConcatDataset(ConcatDataset):
-    def __init__(self, datasets: Iterable[Dataset]) -> None:
-        super().__init__(datasets)
-
-    def collater(self, samples):
-        # TODO For now only supports datasets with same underlying collater implementations
-
-        all_keys = set()
-        for s in samples:
-            all_keys.update(s)
-
-        shared_keys = all_keys
-        for s in samples:
-            shared_keys = shared_keys & set(s.keys())
-
-        samples_shared_keys = []
-        for s in samples:
-            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
-
-        return self.datasets[0].collater(samples_shared_keys)
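The vis_* to x_* rename makes the base class modality-agnostic. A sketch of what a non-image subclass would now look like (the modality and all names here are hypothetical):

    class PointCloudDataset(BaseDataset):
        def __init__(self, pc_processor, text_processor, pc_root, ann_paths):
            # x_processor / x_root are the new modality-neutral slots
            super().__init__(x_processor=pc_processor, text_processor=text_processor,
                             x_root=pc_root, ann_paths=ann_paths)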
minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py (moved from minigpt4/datasets/datasets/cc_sbu_dataset.py)
@@ -2,12 +2,12 @@ import os
 from PIL import Image
 import webdataset as wds
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
-from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
+from minigpt4.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset


 class CCSBUDataset(BaseDataset):
-    def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+    def __init__(self, vision_processor, text_processor, location):
+        super().__init__(x_processor=vision_processor, text_processor=text_processor)

         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
@@ -15,7 +15,7 @@ class CCSBUDataset(BaseDataset):
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode("pilrgb", handler=wds.warn_and_continue),
             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )
@@ -26,7 +26,7 @@ class CCSBUDataset(BaseDataset):
         }


-class CCSBUAlignDataset(CaptionDataset):
+class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):

     def __getitem__(self, index):

@@ -34,10 +34,10 @@ class CCSBUAlignDataset(CaptionDataset):
         ann = self.annotation[index]

         img_file = '{}.jpg'.format(ann["image_id"])
-        image_path = os.path.join(self.vis_root, img_file)
+        image_path = os.path.join(self.x_root, img_file)
         image = Image.open(image_path).convert("RGB")

-        image = self.vis_processor(image)
+        image = self.x_processor(image)
         caption = ann["caption"]

         return {
minigpt4/datasets/datasets/image_caption/image_caption_datasets.py (moved from minigpt4/datasets/datasets/caption_datasets.py)
@@ -6,32 +6,20 @@
 """

 import os
 from collections import OrderedDict

 from minigpt4.datasets.datasets.base_dataset import BaseDataset
 from PIL import Image

-
-class __DisplMixin:
-    def displ_item(self, index):
-        sample, ann = self.__getitem__(index), self.annotation[index]
-
-        return OrderedDict(
-            {
-                "file": ann["image"],
-                "caption": ann["caption"],
-                "image": sample["image"],
-            }
-        )
+from minigpt4.datasets.datasets.mixins.mixins import __ImageDisplMixin


-class CaptionDataset(BaseDataset, __DisplMixin):
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
+    def __init__(self, vision_processor, text_processor, x_root, ann_paths):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         """
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+        super().__init__(vision_processor, text_processor, x_root, ann_paths)

         self.img_ids = {}
         n = 0
@@ -47,10 +35,10 @@ class CaptionDataset(BaseDataset, __DisplMixin):
         ann = self.annotation[index]

         img_file = '{:0>12}.jpg'.format(ann["image_id"])
-        image_path = os.path.join(self.vis_root, img_file)
+        image_path = os.path.join(self.x_root, img_file)
         image = Image.open(image_path).convert("RGB")

-        image = self.vis_processor(image)
+        image = self.x_processor(image)
         caption = self.text_processor(ann["caption"])

         return {
@@ -60,23 +48,23 @@ class CaptionDataset(BaseDataset, __DisplMixin):
         }


-class CaptionEvalDataset(BaseDataset, __DisplMixin):
-    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
+    def __init__(self, vision_processor, text_processor, x_root, ann_paths):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         split (string): val or test
         """
-        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+        super().__init__(vision_processor, text_processor, x_root, ann_paths)

     def __getitem__(self, index):

         ann = self.annotation[index]

-        image_path = os.path.join(self.vis_root, ann["image"])
+        image_path = os.path.join(self.x_root, ann["image"])
         image = Image.open(image_path).convert("RGB")

-        image = self.vis_processor(image)
+        image = self.x_processor(image)

         return {
             "image": image,
minigpt4/datasets/datasets/image_caption/laion_dataset.py (moved from minigpt4/datasets/datasets/laion_dataset.py)
@@ -10,8 +10,8 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset


 class LaionDataset(BaseDataset):
-    def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+    def __init__(self, vision_processor, text_processor, location):
+        super().__init__(x_processor=vision_processor, text_processor=text_processor)

         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
@@ -19,7 +19,7 @@ class LaionDataset(BaseDataset):
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode("pilrgb", handler=wds.warn_and_continue),
             wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )
minigpt4/datasets/datasets/mixins/__init__.py (new file, empty)

minigpt4/datasets/datasets/mixins/mixins.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from collections import OrderedDict


class __ImageDisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        return OrderedDict(
            {
                "file": ann["image"],
                "caption": ann["caption"],
                "image": sample["image"],
            }
        )


class __AudioDisplMixin:
    def displ_item(self, index):
        sample, ann = self.__getitem__(index), self.annotation[index]

        # TODO: Finish the Audio Display Mixin
        '''
        return OrderedDict(
            {
            }
        )
        '''

        raise NotImplementedError
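These mixins only bolt a human-readable preview onto a dataset class, as ImageCaptionDataset above does; a quick sketch of the call (the dataset instance is a placeholder):

    item = dataset.displ_item(0)
    # -> OrderedDict([("file", ...), ("caption", ...), ("image", ...)])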
(ImageBind processor module; file path not shown in this view. The identical removed/added lines suggest a whitespace-only change, presumably adding a newline at end of file.)
@@ -145,4 +145,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
         mean = cfg.get("mean", None)
         std = cfg.get("std", None)

-        return cls(image_size=image_size, mean=mean, std=std)
\ No newline at end of file
+        return cls(image_size=image_size, mean=mean, std=std)
minigpt4/runners/runner_base.py
@@ -24,7 +24,7 @@ from minigpt4.common.dist_utils import (
 )
 from minigpt4.common.registry import registry
 from minigpt4.common.utils import is_url
-from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset
+from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, WrappedChainDataset
 from minigpt4.datasets.datasets.dataloader_utils import (
     IterLoader,
     MultiIterLoader,
@@ -219,7 +219,7 @@ class RunnerBase:
             num_records = sum(
                 [
                     len(d)
-                    if not type(d) in [wds.DataPipeline, ChainDataset]
+                    if not type(d) in [wds.DataPipeline, WrappedChainDataset]
                     else 0
                     for d in self.datasets[split_name]
                 ]
@@ -503,7 +503,7 @@ class RunnerBase:

         def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
             # create a single dataloader for each split
-            if isinstance(dataset, ChainDataset) or isinstance(
+            if isinstance(dataset, WrappedChainDataset) or isinstance(
                 dataset, wds.DataPipeline
             ):
                 # wds.WebdDataset instance are chained together
train_configs/bindgpt4.yaml (new file, 57 lines)
@@ -0,0 +1,57 @@
model:
  arch: bind_gpt4
  model_type: pretrain_vicuna
  freeze_imagebind: True
  freeze_qformer: False


datasets:
  laion:
    vis_processor:
      train:
        name: "imagebind_vision_train"
        image_size: 224
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 115
  cc_sbu:
    vis_processor:
      train:
        name: "imagebind_vision_train"
        image_size: 224
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 14


run:
  task: imagebind_qformer_train
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 1e-4
  min_lr: 8e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 4
  batch_size_train: 64
  batch_size_eval: 64
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/bindgpt4"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True