init audio data config (#2)

- Add audio datasets
- Add audio processors 
- Add audio support in bindgpt
- Add audio training config

---------

Co-authored-by: bingyikang <bingyikang@bytedance.com>
Co-authored-by: zhaoyang <913556700@qq.com>
Bingyi Kang 2023-05-26 11:44:18 +08:00 committed by GitHub
parent 3efda2ac76
commit 05220fe3c1
31 changed files with 717 additions and 267 deletions

README.md
View File

@ -1,170 +1,5 @@
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
**King Abdullah University of Science and Technology**
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
## News
We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo's GPU memory consumption can now be as low as 12GB.
## Online Demo
Click the image to chat with MiniGPT-4 about your images
[![demo](figs/online_demo.png)](https://minigpt-4.github.io)
## Examples
| | |
:-------------------------:|:-------------------------:
![find wild](figs/examples/wop_2.png) | ![write story](figs/examples/ad_2.png)
![solve problem](figs/examples/fix_1.png) | ![write Poem](figs/examples/rhyme_1.png)
More examples can be found in the [project page](https://minigpt-4.github.io).
## Introduction
- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer.
- We train MiniGPT-4 in two stages. The first, traditional pretraining stage is trained on roughly 5 million aligned image-text pairs in 10 hours using 4 A100s. After the first stage, Vicuna is able to understand the image, but its generation ability is heavily impacted.
- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs using the model itself together with ChatGPT. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset.
- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100.
- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4.
![overview](figs/overview.png)
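For readers new to this design, here is a minimal conceptual sketch of what "a single projection layer between frozen modules" means; this is illustrative code with made-up dimensions, not the repository's actual classes:
```python
import torch
import torch.nn as nn

class TinyAligner(nn.Module):
    """Illustrative only: the single trainable projection between a frozen encoder and a frozen LLM."""
    def __init__(self, enc_dim=768, llm_dim=4096):
        super().__init__()
        self.proj = nn.Linear(enc_dim, llm_dim)   # the only parameters updated in stage 1

    def forward(self, frozen_encoder_feats):       # [batch, tokens, enc_dim] from the frozen encoder
        return self.proj(frozen_encoder_feats)     # [batch, tokens, llm_dim], consumed by the frozen LLM

print(TinyAligner()(torch.randn(2, 32, 768)).shape)  # torch.Size([2, 32, 4096])
```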
## Getting Started
### Installation
**1. Prepare the code and the environment**
Git clone our repository, create a Python environment, and activate it via the following commands
```bash
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
cd MiniGPT-4
conda env create -f environment.yml
conda activate minigpt4
```
**2. Prepare the pretrained Vicuna weights**
The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B.
Please refer to our instruction [here](PrepareVicuna.md)
to prepare the Vicuna weights.
The final weights should be in a single folder with a structure similar to the following:
```
vicuna_weights
├── config.json
├── generation_config.json
├── pytorch_model.bin.index.json
├── pytorch_model-00001-of-00003.bin
...
```
Then, set the path to the Vicuna weights in the model config file
[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
**3. Prepare the pretrained MiniGPT-4 checkpoint**
Download the pretrained checkpoints according to the Vicuna model you prepared.
| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing)
Then, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L11) at Line 11.
### Launching Demo Locally
Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running
```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```
To save GPU memory, Vicuna is loaded in 8-bit by default, with a beam search width of 1.
This configuration requires about 23G of GPU memory for Vicuna 13B and 11.5G for Vicuna 7B.
For more powerful GPUs, you can run the model
in 16-bit by setting low_resource to False in the config file
[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
Thanks to [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
### Training
The training of MiniGPT-4 contains two alignment stages.
**1. First pretraining stage**
In the first pretraining stage, the model is trained on image-text pairs from the Laion and CC datasets
to align the vision and language model. To download and prepare the datasets, please check
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
After the first stage, the visual features are mapped and can be understood by the language
model.
To launch the first stage training, run the following command. In our experiments, we use 4 A100s.
You can change the save path in the config file
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
```
A MiniGPT-4 checkpoint with only stage one training can be downloaded
[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
Compared to the model after stage two, this checkpoint frequently generates incomplete and repeated sentences.
**2. Second finetuning stage**
In the second stage, we use a small, high-quality image-text pair dataset created by ourselves
and convert it to a conversation format to further align MiniGPT-4.
To download and prepare our second stage dataset, please check our
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
To launch the second stage alignment,
first specify the path to the checkpoint file trained in stage 1 in
[train_configs/minigpt4_stage2_finetune.yaml](train_configs/minigpt4_stage2_finetune.yaml).
You can also specify the output path there.
Then, run the following command. In our experiments, we use 1 A100.
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
```
After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and in a user-friendly way.
## Acknowledgement
+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check out this great open-source work if you didn't know about it before!
+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX:
```bibtex
@article{zhu2023minigpt,
title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
journal={arXiv preprint arXiv:2304.10592},
year={2023}
}
```
## License
This repository is under [BSD 3-Clause License](LICENSE.md).
Much of the code is based on [Lavis](https://github.com/salesforce/LAVIS), which is under the
BSD 3-Clause License [here](LICENSE_Lavis.md).
# TODO List
- [ ] reconstruct the eval scripts
- [ ] support audio evaluation
- [ ] clean code repo & solve recursive import
- [x] audio processor: random sampling

View File

@ -6,7 +6,9 @@
pip3 install --upgrade pip
pip3 install -r requirements.txt
pip3 install byted-dataloader -i "https://bytedpypi.byted.org/simple"
mmengine-0.7.3
pip3 install mmengine==0.7.3
pip3 install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
# unset http_proxy && unset https_proxy && unset no_proxy
# # ----------------------------------------------------------------------------------------

View File

@ -1,4 +1,5 @@
import torch
import torchaudio
from PIL import Image
@ -13,3 +14,12 @@ def load_image(image, image_processor):
if len(image.shape) == 3:
image = image.unsqueeze(0)
return image
def load_audio(audio, audio_processor):
if isinstance(audio, str):  # audio is a path to an audio file
raw_audio = torchaudio.load(audio)  # (waveform, sample_rate) tuple
audio = audio_processor(raw_audio)
# elif isinstance(audio, )
else:
raise NotImplementedError
return audio
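A hedged usage sketch of the helper above; the `.flac` path is a placeholder, and the processor arguments mirror the defaults used elsewhere in this commit. Note that `torchaudio.load` returns a `(waveform, sample_rate)` tuple, which is what the audio processors below expect:
```python
from minigpt4.processors import ImageBindAudioEvalProcessor

processor = ImageBindAudioEvalProcessor(clip_duration=2, clips_per_video=3,
                                        num_mel_bins=128, target_length=204)
clips = load_audio("example.flac", processor)
# clips: [clips_per_video, channel, mel_bins, target_length], e.g. torch.Size([3, 1, 128, 204])
```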

View File

@ -125,6 +125,7 @@ with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=0.5):
image = gr.Image(type="pil")
# audio = gr.Audio()
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Restart")
@ -150,7 +151,7 @@ with gr.Blocks() as demo:
chat_state = gr.State()
emb_list = gr.State()
chatbot = gr.Chatbot(label='BindGPT-4')
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
text_input = gr.Textbox(label='User', placeholder='Please upload your image/audio first', interactive=False)
upload_button.click(upload_img, [image, text_input, chat_state],
[image, text_input, upload_button, chat_state, emb_list])

View File

@ -15,7 +15,7 @@ import logging
from imagebind.models.multimodal_preprocessors import SimpleTokenizer
from PIL import Image
from pytorchvideo import transforms as pv_transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision import transforms
@ -59,23 +59,11 @@ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
elif p < 0:
fbank = fbank[:, 0:target_length]
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1
# channel image
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1 channel image
fbank = fbank.unsqueeze(0)
return fbank
def get_clip_timepoints(clip_sampler, duration):
# Read out all clips in this video
all_clips_timepoints = []
is_last_clip = False
end = 0.0
while not is_last_clip:
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
all_clips_timepoints.append((start, end))
return all_clips_timepoints
def load_and_transform_vision_data(image_paths, device):
if image_paths is None:
return None
@ -137,14 +125,14 @@ def load_and_transform_audio_data(
waveform = torchaudio.functional.resample(
waveform, orig_freq=sr, new_freq=sample_rate
)
all_clips_timepoints = get_clip_timepoints(
all_clips_timepoints = get_constant_clip_timepoints(
clip_sampler, waveform.size(1) / sample_rate
)
all_clips = []
for clip_timepoints in all_clips_timepoints:
waveform_clip = waveform[
:,
int(clip_timepoints[0] * sample_rate) : int(
int(clip_timepoints[0] * sample_rate): int(
clip_timepoints[1] * sample_rate
),
]
@ -162,7 +150,8 @@ def load_and_transform_audio_data(
return torch.stack(audio_outputs, dim=0)
def get_clip_timepoints(clip_sampler, duration):
def get_constant_clip_timepoints(clip_sampler, duration):
assert isinstance(clip_sampler, ConstantClipsPerVideoSampler), "Incompatible Type of Sampler!"
# Read out all clips in this video
all_clips_timepoints = []
is_last_clip = False
@ -173,6 +162,13 @@ def get_clip_timepoints(clip_sampler, duration):
return all_clips_timepoints
def get_random_clip_timepoints(clip_sampler, duration):
assert isinstance(clip_sampler, RandomMultiClipSampler), "Incompatible Type of Sampler!"
starts, ends, _, _, _ = clip_sampler(0.0, duration, annotation=None)
all_clips_timepoints = sorted(list(zip(starts, ends)), key=lambda x: x[0])
return all_clips_timepoints
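To illustrate the difference between the two samplers (output values below are approximate or made up; both helpers return a list of `(start, end)` timepoints in seconds):
```python
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler

const = ConstantClipsPerVideoSampler(clip_duration=2.0, clips_per_video=3)
get_constant_clip_timepoints(const, 10.0)   # evenly spread, roughly [(0.0, 2.0), (2.67, 4.67), (5.33, 7.33)]

rand = RandomMultiClipSampler(clip_duration=2.0, num_clips=3)
get_random_clip_timepoints(rand, 10.0)      # random but sorted, e.g. [(0.7, 2.7), (3.1, 5.1), (6.4, 8.4)]
```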
def crop_boxes(boxes, x_offset, y_offset):
"""
Perform crop on the bounding boxes given the offsets.
@ -244,7 +240,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
x_offset = 0
elif spatial_idx == 2:
x_offset = width - size
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
if ndim == 3:
cropped = cropped.squeeze(0)
@ -328,7 +324,7 @@ def load_and_transform_video_data(
**{"sample_rate": sample_rate},
)
all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
all_clips_timepoints = get_constant_clip_timepoints(clip_sampler, video.duration)
all_video = []
for clip_timepoints in all_clips_timepoints:
@ -347,4 +343,4 @@ def load_and_transform_video_data(
all_video = torch.stack(all_video, dim=0)
video_outputs.append(all_video)
return torch.stack(video_outputs, dim=0).to(device)
return torch.stack(video_outputs, dim=0).to(device)

View File

@ -50,46 +50,59 @@ ModalityType = SimpleNamespace(
class ImageBindJoiner(nn.Module):
def __init__(self,
vision_query_token_num: int,
audio_query_token_num: int,
vision_qformer_frozen: bool = False,
vision_qformer_model: str = "", # The url or path of pre-trained vision Q-Former model
vision_pre_dims: List[int] = (), # Projection before Q-Former
vision_post_dims: List[int] = (768, 768) # Projection after Q-Former
vision_post_dims: List[int] = (1280, 768), # Projection after Q-Former
audio_pre_dims: List[int] = (),
audio_post_dims: List[int] = (768, 768)
):
super().__init__()
assert not (vision_qformer_frozen and vision_qformer_model == "")
self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims)
self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims, audio_pre_dims)
self.modality_qformers = self._create_modality_qformers(vision_query_token_num,
vision_qformer_frozen,
vision_qformer_model)
self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims)
vision_qformer_model,
audio_query_token_num)
self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims, audio_post_dims)
def _create_modality_pre_projectors(self,
vision_pre_dims
vision_pre_dims,
audio_pre_dims
):
modality_pre_projectors = {
ModalityType.VISION: create_projectors(vision_pre_dims)
ModalityType.VISION: create_projectors(vision_pre_dims),
ModalityType.AUDIO: create_projectors(audio_pre_dims)
}
return modality_pre_projectors
def _create_modality_qformers(self,
vision_query_token_num,
vision_qformer_frozen,
vision_qformer_model
vision_qformer_model,
audio_query_token_num
):
vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
freeze_qformer=vision_qformer_frozen,
encoder_width=1280, # TODO: fix hard-coding
q_former_model=vision_qformer_model)
audio_qformer = SequenceGenericQFormer(num_query_token=audio_query_token_num,
freeze_qformer=False,
encoder_width=768)
modality_qformers = {
ModalityType.VISION: vision_qformer
ModalityType.VISION: vision_qformer,
ModalityType.AUDIO: audio_qformer
}
return nn.ModuleDict(modality_qformers)
def _create_modality_post_projectors(self, vision_post_dims):
def _create_modality_post_projectors(self, vision_post_dims, audio_post_dims):
vision_projector = create_projectors(vision_post_dims)
audio_projector = create_projectors(audio_post_dims)
modality_projectors = {
ModalityType.VISION: vision_projector
ModalityType.VISION: vision_projector,
ModalityType.AUDIO: audio_projector
}
return nn.ModuleDict(modality_projectors)
@ -97,7 +110,6 @@ class ImageBindJoiner(nn.Module):
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
outputs = {}
for modality_key, modality_value in inputs.items():
# assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported."
if modality_value is not None:
modality_value = self.modality_pre_projectors[modality_key](modality_value)
modality_value = self.modality_qformers[modality_key](modality_value)
@ -544,7 +556,7 @@ class ImageBindModel(nn.Module):
# NOTE: The reduction operation has been modified.
if reduce_list:
modality_value = modality_value.reshape(B, S, *modality_value[2:])
modality_value = modality_value.reshape(B, S, *modality_value.shape[1:])
modality_value = modality_value.mean(dim=1)
outputs[modality_key] = modality_value

View File

@ -32,10 +32,12 @@ class Registry:
"""
def wrap(builder_cls):
# TODO: merge them or split builders by modality
from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
assert issubclass(
builder_cls, ImageBaseDatasetBuilder
builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder)
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
builder_cls
)

View File

@ -0,0 +1,5 @@
datasets:
audioset:
data_type: audio
build_info:
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/AudioSet_SL/AudioSet_SL{00..54}.tar
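The `{00..54}` suffix is webdataset brace notation for a range of shard files. A quick hedged check of how it expands (`SimpleShardList` is a standard webdataset helper; `ResampledShards`, used by the audio dataset below, expands the same way):
```python
import webdataset as wds

shards = wds.SimpleShardList(
    "/mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/AudioSet_SL/AudioSet_SL{00..54}.tar")
print(len(list(shards)))  # 55 shard URLs: AudioSet_SL00.tar ... AudioSet_SL54.tar
```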

View File

@ -0,0 +1,5 @@
datasets:
bbc:
data_type: audio
build_info:
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/BBC_Sound_Effects/BBC_Sound_Effects{000000..000062}.tar

View File

@ -0,0 +1,5 @@
datasets:
freesound:
data_type: audio
build_info:
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/FreeSound/FreeSound{000000..000524}.tar

View File

@ -0,0 +1,5 @@
datasets:
soundbible:
data_type: audio
build_info:
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/SoundBible/SoundBible0.tar

View File

@ -11,12 +11,23 @@ from minigpt4.datasets.builders.image_text_pair_builder import (
LaionBuilderImage,
CCSBUAlignBuilderImage
)
from minigpt4.datasets.builders.audio_text_pair_builder import (
BBCBuilder,
AudioSetBuilder,
SoundBibleBuilder,
FreeSoundBuilder
)
from minigpt4.common.registry import registry
__all__ = [
"CCSBUBuilderImage",
"LaionBuilderImage",
"CCSBUAlignBuilderImage"
"CCSBUAlignBuilderImage",
# Audio builders
"BBCBuilder",
"AudioSetBuilder",
"SoundBibleBuilder",
"FreeSoundBuilder",
]

View File

@ -31,6 +31,51 @@ class AudioBaseDatasetBuilder:
self.data_type = self.config.data_type
self.audio_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
def build_datasets(self):
# download, split, etc...
# only called on 1 GPU/TPU in distributed
if is_main_process():
self._download_data()
if is_dist_avail_and_initialized():
dist.barrier()
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
logging.info("Building datasets...")
datasets = self.build() # dataset['train'/'val'/'test']
return datasets
def build_processors(self):
aud_proc_cfg = self.config.get("audio_processor")
txt_proc_cfg = self.config.get("text_processor")
if aud_proc_cfg is not None:
aud_train_cfg = aud_proc_cfg.get("train")
aud_eval_cfg = aud_proc_cfg.get("eval")
self.audio_processors["train"] = self._build_proc_from_cfg(aud_train_cfg)
self.audio_processors["eval"] = self._build_proc_from_cfg(aud_eval_cfg)
if txt_proc_cfg is not None:
txt_train_cfg = txt_proc_cfg.get("train")
txt_eval_cfg = txt_proc_cfg.get("eval")
self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
@staticmethod
def _build_proc_from_cfg(cfg):
return (
registry.get_processor_class(cfg.name).from_config(cfg)
if cfg is not None
else None
)
@classmethod
def default_config_path(cls, type="default"):
return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
@ -95,5 +140,3 @@ class AudioBaseDatasetBuilder:
filename = os.path.basename(storage_path)
download_url(url=url_or_filename, root=dirname, filename=filename)

View File

@ -0,0 +1,56 @@
import os
import logging
import warnings
from minigpt4.common.registry import registry
from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
from minigpt4.datasets.datasets.audio_caption import GenericAudioDataset
class GenericAudioBuilder(AudioBaseDatasetBuilder):
train_dataset_cls = GenericAudioDataset
def _download_ann(self):
pass
def _download_aud(self):
pass
def build(self):
self.build_processors()
build_info = self.config.build_info
datasets = dict()
split = "train"
# create datasets
dataset_cls = self.train_dataset_cls
datasets[split] = dataset_cls(
audio_processor=self.audio_processors[split],
text_processor=self.text_processors[split],
location=build_info.storage,
).inner_dataset
return datasets
@registry.register_builder("bbc")
class BBCBuilder(GenericAudioBuilder):
DATASET_CONFIG_DICT = {"default": "configs/datasets/bbc/defaults.yaml"}
@registry.register_builder("audioset")
class AudioSetBuilder(GenericAudioBuilder):
DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults.yaml"}
@registry.register_builder("soundbible")
class SoundBibleBuilder(GenericAudioBuilder):
DATASET_CONFIG_DICT = {"default": "configs/datasets/soundbible/defaults.yaml"}
@registry.register_builder("freesound")
class FreeSoundBuilder(GenericAudioBuilder):
DATASET_CONFIG_DICT = {"default": "configs/datasets/freesound/defaults.yaml"}
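A hedged sketch of how one of these registered builders might be exercised end to end; it assumes the registry's `get_builder_class` accessor and that the builder constructor falls back to the default config listed in `DATASET_CONFIG_DICT`:
```python
from minigpt4.common.registry import registry

builder_cls = registry.get_builder_class("soundbible")   # -> SoundBibleBuilder
builder = builder_cls()                                   # assumed to load configs/datasets/soundbible/defaults.yaml
datasets = builder.build_datasets()                       # {"train": <webdataset pipeline over the .tar shards>}
```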

View File

@ -0,0 +1 @@
from minigpt4.datasets.datasets.audio_caption.audio_caption_datasets import GenericAudioDataset

View File

@ -6,21 +6,22 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset
class GenericAudioDataset(BaseDataset):
def __init__(self, vision_processor, text_processor, location):
super().__init__(x_processor=vision_processor, text_processor=text_processor)
def __init__(self, audio_processor, text_processor, location):
super().__init__(x_processor=audio_processor, text_processor=text_processor)
self.inner_dataset = wds.DataPipeline(
wds.ResampledShards(location),
wds.tarfile_to_samples(handler=wds.warn_and_continue),
wds.shuffle(1000, handler=wds.warn_and_continue),
wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
wds.to_tuple("flac", "json", handler=wds.warn_and_continue),
wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
wds.map(self.to_dict, handler=wds.warn_and_continue),
)
def to_dict(self, sample):
return {
"image": sample[0],
"audio": sample[0],
# [clips_per_video, channel, mel_bins, time_steps]
"text_input": self.text_processor(sample[1]["caption"]),
}
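For reference, a hedged sketch of what one sample from this pipeline looks like; the shard path is a placeholder and the plain lambda stands in for a caption processor:
```python
from minigpt4.processors import ImageBindAudioTrainProcessor

ds = GenericAudioDataset(
    audio_processor=ImageBindAudioTrainProcessor(clip_duration=2, clips_per_video=3,
                                                 num_mel_bins=128, target_length=204),
    text_processor=lambda caption: caption,   # stand-in for the caption processor
    location="/path/to/SoundBible0.tar",      # placeholder shard
)
sample = next(iter(ds.inner_dataset))
sample["audio"].shape   # e.g. torch.Size([3, 1, 128, 204]) -> [clips_per_video, channel, mel_bins, time_steps]
sample["text_input"]    # caption string taken from the paired json
```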

View File

@ -21,7 +21,7 @@ class CCSBUDataset(BaseDataset):
def to_dict(self, sample):
return {
"image": sample[0],
"vision": sample[0],
"text_input": self.text_processor(sample[1]["caption"]),
}
@ -41,7 +41,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
caption = ann["caption"]
return {
"image": image,
"vision": image,
"text_input": caption,
"image_id": self.img_ids[ann["image_id"]],
}
@ -49,7 +49,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
class CCDataset(BaseDataset):
def __init__(self, vis_processor, text_processor, location):
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
super().__init__(x_processor=vis_processor, text_processor=text_processor)
self.inner_dataset = wds.DataPipeline(
wds.ResampledShards(location),
@ -57,12 +57,12 @@ class CCDataset(BaseDataset):
wds.shuffle(1000, handler=wds.warn_and_continue),
wds.decode("pilrgb", handler=wds.warn_and_continue),
wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
wds.map(self.to_dict, handler=wds.warn_and_continue),
)
def to_dict(self, sample):
return {
"image": sample[0],
"vision": sample[0],
"text_input": sample[1],
}

View File

@ -42,7 +42,7 @@ class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
caption = self.text_processor(ann["caption"])
return {
"image": image,
"vision": image,
"text_input": caption,
"image_id": self.img_ids[ann["image_id"]],
}
@ -67,7 +67,7 @@ class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
image = self.x_processor(image)
return {
"image": image,
"vision": image,
"image_id": ann["image_id"],
"instance_id": ann["instance_id"],
}

View File

@ -25,7 +25,7 @@ class LaionDataset(BaseDataset):
def to_dict(self, sample):
return {
"image": sample[0],
"vision": sample[0],
"text_input": self.text_processor(sample[1]["caption"]),
}

View File

@ -9,7 +9,7 @@ class __ImageDisplMixin:
{
"file": ann["image"],
"caption": ann["caption"],
"image": sample["image"],
"vision": sample["vision"],
}
)

View File

@ -1,8 +1,9 @@
import random
from typing import Dict, Tuple
from typing import Dict, Tuple, List
import torch
import torch.nn as nn
import re
from torch import Tensor
from transformers import LlamaTokenizer
@ -12,6 +13,36 @@ from minigpt4.models.blip2 import BaseModel
from minigpt4.models.modeling_llama import LlamaForCausalLM
def filter_prompt(input_embeds: Dict[str, Tensor], prompt_list: List[str]) -> List[str]:
if not prompt_list:
return prompt_list
input_modal_set = set([k.title() for k in input_embeds if input_embeds[k] is not None])
prompt_modal_sets = [set(re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)) for prompt in prompt_list]
results = [prompt_list[i] for i, prompt_modal_set in enumerate(prompt_modal_sets) if
prompt_modal_set == input_modal_set]
return results
def arrange_modalities(input_embeds: Dict[str, Tensor], prompt: str) -> List[Tensor]:
prompt_modalities = re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)
return [input_embeds[modality.lower()] for modality in prompt_modalities]
def concat_all_embeddings(input_embeds: Dict[str, Tensor], dim: int) -> Tensor:
embeds = [input_embeds[key] for key in input_embeds if input_embeds[key] is not None]
return torch.cat(embeds, dim=dim)
def filter_modalities(inputs):
filtered_inputs = {}
for k in ModalityType.__dict__.values():
if k in inputs:
filtered_inputs[k] = inputs[k]
return filtered_inputs
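A hedged illustration of the prompt-template convention these helpers rely on; the prompt strings are made up for the example, but the `<Modality><ModalityHere></Modality>` tag format is the one matched by the regular expression above:
```python
import torch

prompts = [
    "###Human: <Vision><ModalityHere></Vision> Describe the image. ###Assistant:",
    "###Human: <Audio><ModalityHere></Audio> Describe the sound. ###Assistant:",
]
embeds = {"vision": None, "audio": torch.randn(2, 32, 4096)}   # an audio-only batch
filter_prompt(embeds, prompts)
# -> only the second prompt survives, because its modality tags match the present modalities {"Audio"}
```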
@registry.register_model("bind_gpt4")
class BindGPT4(BaseModel):
"""
@ -61,7 +92,9 @@ class BindGPT4(BaseModel):
print('Loading Q-Former and Adapter/Projector')
self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
vision_qformer_frozen=freeze_qformer,
vision_post_dims=[768, self.llama_model.config.hidden_size]
vision_post_dims=[768, self.llama_model.config.hidden_size],
audio_query_token_num=num_query_token,
audio_post_dims=[768, self.llama_model.config.hidden_size]
# vision_qformer_model=q_former_model,
# vision_pre_dims=(1280, 1408)
)
@ -84,42 +117,37 @@ class BindGPT4(BaseModel):
def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
imagebind_outputs = self.multimodal_encoder(inputs)
llama_inputs = self.multimodal_joiner(imagebind_outputs)
# NOTE: only accept image here
return llama_inputs
def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
# TODO: Accept More Modalities.
input_embeds = inputs[modality_name]
attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
if prompt:
batch_size = input_embeds.shape[0]
p_before, p_after = prompt.split('<ModalityHere>')
p_before_tokens = self.llama_tokenizer(
p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
p_after_tokens = self.llama_tokenizer(
p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
wrapped_input_embeds = torch.cat([p_before_embeds, inputs, p_after_embeds], dim=1)
wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1])
return wrapped_input_embeds, wrapped_atts_input
else:
def prompt_wrap(self, inputs: Dict[str, Tensor], prompt: str) -> Tuple[Tensor, Tensor]:
if not prompt:
input_embeds = concat_all_embeddings(inputs, dim=1)
attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
return input_embeds, attns_input
input_embeds_list = arrange_modalities(inputs, prompt)
batch_size = input_embeds_list[0].shape[0]
prompt_slices = prompt.split('<ModalityHere>')
prompt_tokens = [self.llama_tokenizer(prompt_slice, return_tensors="pt", add_special_tokens=False)
.to(input_embeds_list[0].device) for prompt_slice in prompt_slices]
prompt_embeds = [self.llama_model.model.embed_tokens(prompt_token.input_ids).expand(batch_size, -1, -1)
for prompt_token in prompt_tokens]
result_embeds = [emb for pair in zip(prompt_embeds[:-1], input_embeds_list)
for emb in pair] + [prompt_embeds[-1]]
wrapped_input_embeds = torch.cat(result_embeds, dim=1)
wrapped_atts_input = torch.ones(wrapped_input_embeds.size()[:-1],
dtype=torch.long).to(wrapped_input_embeds.device)
return wrapped_input_embeds, wrapped_atts_input
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
TODO: More Modalities.
Only accept image inputs here.
Other modalities will conflict with the pre-defined prompt and wrapping strategy.
"""
bind_inputs = {ModalityType.VISION: inputs['image']}
embeds = self.encode_inputs(bind_inputs)
# assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
if self.prompt_list:
prompt = random.choice(self.prompt_list)
# filter `inputs` as it may contain information other than modalities
modality_inputs = filter_modalities(inputs)
embeds = self.encode_inputs(modality_inputs)
filtered_prompts = filter_prompt(embeds, self.prompt_list)
if filtered_prompts:
prompt = random.choice(filtered_prompts)
else:
prompt = None
img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
input_embs, input_atts = self.prompt_wrap(embeds, prompt)
# NOTE: No modifications from the next line to the end. Except for the autocast part.
@ -134,28 +162,28 @@ class BindGPT4(BaseModel):
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False
).to(img_embeds.device)
).to(input_embs.device)
targets = to_regress_tokens.input_ids.masked_fill(
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
)
empty_targets = (
torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
dtype=torch.long).to(img_embeds.device).fill_(-100) # plus one for bos
torch.ones([input_atts.shape[0], input_atts.shape[1] + 1],
dtype=torch.long).to(input_embs.device).fill_(-100) # plus one for bos
)
targets = torch.cat([empty_targets, targets], dim=1)
batch_size = img_embeds.shape[0]
batch_size = input_embs.shape[0]
bos = torch.ones([batch_size, 1],
dtype=to_regress_tokens.input_ids.dtype,
device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
bos_embeds = self.llama_model.model.embed_tokens(bos)
atts_bos = atts_img[:, :1]
atts_bos = input_atts[:, :1]
to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
inputs_embeds = torch.cat([bos_embeds, input_embs, to_regress_embeds], dim=1)
attention_mask = torch.cat([atts_bos, input_atts, to_regress_tokens.attention_mask], dim=1)
outputs = self.llama_model(
inputs_embeds=inputs_embeds,

View File

@ -144,7 +144,7 @@ def compute_sim_matrix(model, data_loader, **kwargs):
vit_feats = []
image_embeds = []
for samples in data_loader:
image = samples["image"]
image = samples["vision"]
image = image.to(model.device)
image_feat, vit_feat = model.forward_image(image)

View File

@ -164,7 +164,7 @@ class MiniGPT4(Blip2Base):
return img_embeds, atts_img
def forward(self, samples):
image = samples["image"]
image = samples["vision"]
img_embeds, atts_img = self.encode_img(image)
if hasattr(samples, 'question_split'): # VQA dataset
print('VQA Batch')

View File

@ -11,11 +11,15 @@ from minigpt4.processors.blip_processors import (
Blip2ImageEvalProcessor,
BlipCaptionProcessor,
)
from minigpt4.processors.imagebind_processor import (
from minigpt4.processors.imagebind_vision_processor import (
ImageBindCaptionProcessor,
ImageBindVisionTrainProcessor,
ImageBindVisionEvalProcessor
)
from minigpt4.processors.imagebind_audio_processor import (
ImageBindAudioTrainProcessor,
ImageBindAudioEvalProcessor,
)
from minigpt4.common.registry import registry
@ -26,7 +30,9 @@ __all__ = [
"BlipCaptionProcessor",
"ImageBindCaptionProcessor",
"ImageBindVisionTrainProcessor",
"ImageBindVisionEvalProcessor"
"ImageBindVisionEvalProcessor",
"ImageBindAudioTrainProcessor",
"ImageBindAudioEvalProcessor",
]

View File

@ -0,0 +1,190 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : x.mei@surrey.ac.uk
"""
Implementation of SpecAugment++,
adapted from Qiuqiang Kong's torchlibrosa:
https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
"""
import torch
import torch.nn as nn
class DropStripes:
def __init__(self, dim, drop_width, stripes_num):
""" Drop stripes.
args:
dim: int, dimension along which to drop
drop_width: int, maximum width of stripes to drop
stripes_num: int, how many stripes to drop
"""
super(DropStripes, self).__init__()
assert dim in [2, 3] # dim 2: time; dim 3: frequency
self.dim = dim
self.drop_width = drop_width
self.stripes_num = stripes_num
def __call__(self, input):
"""input: (batch_size, channels, time_steps, freq_bins)"""
assert input.ndimension() == 4
batch_size = input.shape[0]
total_width = input.shape[self.dim]
for n in range(batch_size):
self.transform_slice(input[n], total_width)
return input
def transform_slice(self, e, total_width):
""" e: (channels, time_steps, freq_bins)"""
for _ in range(self.stripes_num):
distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
if self.dim == 2:
e[:, bgn: bgn + distance, :] = 0
elif self.dim == 3:
e[:, :, bgn: bgn + distance] = 0
class MixStripes:
def __init__(self, dim, mix_width, stripes_num):
""" Mix stripes
args:
dim: int, dimension along which to mix
mix_width: int, maximum width of stripes to mix
stripes_num: int, how many stripes to mix
"""
super(MixStripes, self).__init__()
assert dim in [2, 3]
self.dim = dim
self.mix_width = mix_width
self.stripes_num = stripes_num
def __call__(self, input):
"""input: (batch_size, channel, time_steps, freq_bins)"""
assert input.ndimension() == 4
batch_size = input.shape[0]
total_width = input.shape[self.dim]
rand_sample = input[torch.randperm(batch_size)]
for i in range(batch_size):
self.transform_slice(input[i], rand_sample[i], total_width)
return input
def transform_slice(self, input, random_sample, total_width):
for _ in range(self.stripes_num):
distance = torch.randint(low=0, high=self.mix_width, size=(1,))[0]
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
if self.dim == 2:
input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \
0.5 * random_sample[:, bgn: bgn + distance, :]
elif self.dim == 3:
input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \
0.5 * random_sample[:, :, bgn: bgn + distance]
class CutStripes:
def __init__(self, dim, cut_width, stripes_num):
""" Cutting stripes with another randomly selected sample in mini-batch.
args:
dim: int, dimension along which to cut
cut_width: int, maximum width of stripes to cut
stripes_num: int, how many stripes to cut
"""
super(CutStripes, self).__init__()
assert dim in [2, 3]
self.dim = dim
self.cut_width = cut_width
self.stripes_num = stripes_num
def __call__(self, input):
"""input: (batch_size, channel, time_steps, freq_bins)"""
assert input.ndimension() == 4
batch_size = input.shape[0]
total_width = input.shape[self.dim]
rand_sample = input[torch.randperm(batch_size)]
for i in range(batch_size):
self.transform_slice(input[i], rand_sample[i], total_width)
return input
def transform_slice(self, input, random_sample, total_width):
for _ in range(self.stripes_num):
distance = torch.randint(low=0, high=self.cut_width, size=(1,))[0]
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
if self.dim == 2:
input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :]
elif self.dim == 3:
input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance]
class SpecAugmentation:
def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
mask_type='mixture'):
"""Spec augmentation and SpecAugment++.
[ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
and Le, Q.V., 2019. Specaugment: A simple data augmentation method
for automatic speech recognition. arXiv preprint arXiv:1904.08779.
[ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
Data Augmentation Method for Acoustic Scene Classification. arXiv
preprint arXiv:2103.16858.
Args:
time_drop_width: int
time_stripes_num: int
freq_drop_width: int
freq_stripes_num: int
mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)
"""
super(SpecAugmentation, self).__init__()
if mask_type == 'zero_value':
self.time_augmentator = DropStripes(dim=2, drop_width=time_drop_width,
stripes_num=time_stripes_num)
self.freq_augmentator = DropStripes(dim=3, drop_width=freq_drop_width,
stripes_num=freq_stripes_num)
elif mask_type == 'mixture':
self.time_augmentator = MixStripes(dim=2, mix_width=time_drop_width,
stripes_num=time_stripes_num)
self.freq_augmentator = MixStripes(dim=3, mix_width=freq_drop_width,
stripes_num=freq_stripes_num)
elif mask_type == 'cutting':
self.time_augmentator = CutStripes(dim=2, cut_width=time_drop_width,
stripes_num=time_stripes_num)
self.freq_augmentator = CutStripes(dim=3, cut_width=freq_drop_width,
stripes_num=freq_stripes_num)
else:
raise NameError('No such mask type in SpecAugment++')
def __call__(self, inputs):
# x should be in size [batch_size, channel, time_steps, freq_bins]
x = self.time_augmentator(inputs)
x = self.freq_augmentator(x)
return x
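A brief usage sketch of the augmentation above; shapes follow the `[batch_size, channel, time_steps, freq_bins]` convention noted in the comments, and the widths match the defaults used by the audio train processor later in this commit:
```python
import torch

spec_aug = SpecAugmentation(time_drop_width=13, time_stripes_num=2,
                            freq_drop_width=8, freq_stripes_num=2, mask_type='mixture')
batch = torch.randn(4, 1, 204, 128)   # [batch_size, channel, time_steps, freq_bins]
augmented = spec_aug(batch)           # same shape; random stripes blended with other samples in the batch
```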

View File

@ -9,7 +9,7 @@ import re
from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor
from minigpt4.processors.randaugment import RandomAugment
from minigpt4.processors.vision_augment import RandomAugment
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

View File

@ -0,0 +1,166 @@
from typing import Union, List
import torch
import torchaudio
from omegaconf import OmegaConf
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
from torch import Tensor
from imagebind.data.data_utils import waveform2melspec, get_constant_clip_timepoints, \
get_random_clip_timepoints
from minigpt4.processors.base_processor import BaseProcessor
from torchvision import transforms
from minigpt4.common.registry import registry
from minigpt4.processors.audio_augment import SpecAugmentation
class ImageBindAudioBaseProcessor(BaseProcessor):
def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
num_mel_bins=None, target_length=None, clip_sample_method="Random"):
super().__init__()
self.mean = -4.268 if mean is None else mean
self.std = 9.138 if std is None else std
self.target_sr = 16000 if target_sr is None else target_sr
self.num_mel_bins = num_mel_bins
self.target_length = target_length
self.clip_sampler = self._construct_clip_sampler(clip_duration, clips_per_video, clip_sample_method)
self.normalize = transforms.Normalize(self.mean, self.std)
def _construct_clip_sampler(self, clip_duration, clips_per_video, clip_sample_method):
if clip_duration is None or clips_per_video is None:
return None
if clip_sample_method == "Constant":
return ConstantClipsPerVideoSampler(
clip_duration=clip_duration, clips_per_video=clips_per_video
)
elif clip_sample_method == "Random":
return RandomMultiClipSampler(clip_duration=clip_duration, num_clips=clips_per_video)
else:
raise NotImplementedError
def waveform_resample(self, waveform: Tensor, origin_sr: int) -> Tensor:
return torchaudio.functional.resample(
waveform, orig_freq=origin_sr, new_freq=self.target_sr
)
def clip_sample(self, waveform: Tensor) -> List[Tensor]:
if self.clip_sampler is None:
return [waveform]
elif isinstance(self.clip_sampler, ConstantClipsPerVideoSampler):
all_clips_timepoints = get_constant_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
elif isinstance(self.clip_sampler, RandomMultiClipSampler):
all_clips_timepoints = get_random_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
else:
raise NotImplementedError
all_clips = []
for clip_timepoints in all_clips_timepoints:
start_pos = int(clip_timepoints[0] * self.target_sr)
end_pos = int(clip_timepoints[1] * self.target_sr)
waveform_clip = waveform[:, start_pos: end_pos]
all_clips.append(waveform_clip)
return all_clips
def waveform_melspec(self, waveforms: Union[List[Tensor], Tensor]) -> List[Tensor]:
if isinstance(waveforms, Tensor):
return waveform2melspec(waveforms, self.target_sr, self.num_mel_bins, self.target_length)
else:
return [waveform2melspec(waveform, self.target_sr, self.num_mel_bins, self.target_length)
for waveform in waveforms]
@registry.register_processor("imagebind_audio_train")
class ImageBindAudioTrainProcessor(ImageBindAudioBaseProcessor):
def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
clip_sample_method="Random", num_mel_bins=None, target_length=None, time_drop_width=13,
time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2, mask_type='mixture'):
super().__init__(mean=mean, std=std, target_sr=target_sr,
clip_duration=clip_duration, clips_per_video=clips_per_video,
num_mel_bins=num_mel_bins, target_length=target_length,
clip_sample_method=clip_sample_method)
self.spec_augment = SpecAugmentation(time_drop_width, time_stripes_num,
freq_drop_width, freq_stripes_num, mask_type)
def __call__(self, item):
# item: Tuple[Tensor, int]
waveform, origin_sr = item[0], item[1]
waveform = self.waveform_resample(waveform, origin_sr)
waveform_clips = self.clip_sample(waveform)
melspec_clips = self.waveform_melspec(waveform_clips)
normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
all_clips = torch.stack(normed_melspecs, dim=0)
# all_clips: [clips_per_video, channel, mel_bins, time_steps]
# augment: [batch_size, channel, time_steps, freq_bins]
augmented_clips = self.spec_augment(all_clips.transpose(-2, -1)).transpose(-2, -1)
return augmented_clips
@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()
target_sr = cfg.get("target_sr", 16000)
clip_duration = cfg.get("clip_duration", 2)
clips_per_video = cfg.get("clips_per_video", 3)
num_mel_bins = cfg.get("num_mel_bins", 128)
target_length = cfg.get("target_length", 204)
time_drop_width = cfg.get("time_drop_width", 13)
time_stripes_num = cfg.get("time_stripes_num", 2)
# 13 * 2 / 204 = 12.75% Time Mask
freq_drop_width = cfg.get("freq_drop_width", 8)
freq_stripes_num = cfg.get("freq_stripes_num", 2)
# 8 * 2 / 128 = 12.5% Freq Mask
mask_type = cfg.get("mask_type", 'mixture')
mean = cfg.get("mean", None)
std = cfg.get("std", None)
return cls(
mean=mean, std=std, target_sr=target_sr,
clip_duration=clip_duration, clips_per_video=clips_per_video,
num_mel_bins=num_mel_bins, target_length=target_length,
time_drop_width=time_drop_width, time_stripes_num=time_stripes_num,
freq_drop_width=freq_drop_width, freq_stripes_num=freq_stripes_num,
mask_type=mask_type
)
@registry.register_processor("imagebind_audio_eval")
class ImageBindAudioEvalProcessor(ImageBindAudioBaseProcessor):
def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
clip_sample_method="Constant", num_mel_bins=None, target_length=None):
super().__init__(mean=mean, std=std, target_sr=target_sr,
clip_duration=clip_duration, clips_per_video=clips_per_video,
num_mel_bins=num_mel_bins, target_length=target_length,
clip_sample_method=clip_sample_method)
def __call__(self, item):
# item: Tuple[Tensor, int]
waveform, origin_sr = item[0], item[1]
waveform = self.waveform_resample(waveform, origin_sr)
waveform_clips = self.clip_sample(waveform)
melspec_clips = self.waveform_melspec(waveform_clips)
normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
all_clips = torch.stack(normed_melspecs, dim=0)
# all_clips: [clips_per_video, channel, mel_bins, time_steps]
return all_clips
@classmethod
def from_config(cls, cfg=None):
if cfg is None:
cfg = OmegaConf.create()
target_sr = cfg.get("target_sr", 16000)
clip_duration = cfg.get("clip_duration", 2)
clips_per_video = cfg.get("clips_per_video", 3)
num_mel_bins = cfg.get("num_mel_bins", 128)
target_length = cfg.get("target_length", 204)
mean = cfg.get("mean", None)
std = cfg.get("std", None)
return cls(
mean=mean, std=std, target_sr=target_sr,
clip_duration=clip_duration, clips_per_video=clips_per_video,
num_mel_bins=num_mel_bins, target_length=target_length
)
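A hedged sketch of building these processors from a config node, mirroring how `build_processors` resolves the `audio_processor` entries of the training YAML later in this commit; the keys are the ones read in `from_config`:
```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"name": "imagebind_audio_train",
                        "clip_duration": 2, "clips_per_video": 3,
                        "num_mel_bins": 128, "target_length": 204})
train_proc = ImageBindAudioTrainProcessor.from_config(cfg)
eval_proc = ImageBindAudioEvalProcessor.from_config()   # falls back to the defaults hard-coded in from_config
```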

View File

@ -9,7 +9,7 @@ import re
from minigpt4.common.registry import registry
from minigpt4.processors.base_processor import BaseProcessor
from minigpt4.processors.randaugment import RandomAugment
from minigpt4.processors.vision_augment import RandomAugment
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
@ -147,3 +147,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
return cls(image_size=image_size, mean=mean, std=std)

View File

@ -53,7 +53,6 @@ ipdb
tensorflow-cpu
tensorboardX
# mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
bs4==0.0.1 # Needed for text cleaning
bson==0.5.10
byted-dataloader==0.3.6
@ -65,4 +64,5 @@ sentencepiece==0.1.99 # Needed for T5 tokenizer
tensorboard==2.11.2
tensorflow==2.11.0 # Needed for tensorboard hdfs support
tensorflow-io==0.30.0 # Needed for tensorboard hdfs support
tqdm==4.64.1
tqdm==4.64.1
pytorchvideo==0.1.5

View File

@ -0,0 +1,68 @@
model:
arch: bind_gpt4
model_type: pretrain_vicuna
freeze_imagebind: True
freeze_qformer: False
datasets:
bbc:
audio_processor:
train:
name: "imagebind_audio_train"
text_processor:
train:
name: "imagebind_caption"
sample_ratio: 31
audioset:
audio_processor:
train:
name: "imagebind_audio_train"
text_processor:
train:
name: "imagebind_caption"
sample_ratio: 108
soundbible:
audio_processor:
train:
name: "imagebind_audio_train"
text_processor:
train:
name: "imagebind_caption"
sample_ratio: 2
freesound:
audio_processor:
train:
name: "imagebind_audio_train"
text_processor:
train:
name: "imagebind_caption"
sample_ratio: 262
run:
task: image_text_pretrain
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 1e-4
min_lr: 8e-5
warmup_lr: 1e-6
weight_decay: 0.05
max_epoch: 4
batch_size_train: 64
batch_size_eval: 64
num_workers: 4
warmup_steps: 5000
iters_per_epoch: 5000
seed: 42
output_dir: "output/bindgpt4_stage1_audio_pretrain"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True