From 05220fe3c13f1b7af6aa2e052a2e9240995f07d6 Mon Sep 17 00:00:00 2001 From: Bingyi Kang Date: Fri, 26 May 2023 11:44:18 +0800 Subject: [PATCH] init audio data config (#2) - Add audio datasets - Add audio processors - Add audio support in bindgpt - Add audio training config --------- Co-authored-by: bingyikang Co-authored-by: zhaoyang <913556700@qq.com> --- README.md | 175 +--------------- arnold_before.sh | 4 +- eval_scripts/eval_utils.py | 10 + eval_scripts/qualitative_eval.py | 3 +- imagebind/data/data_utils.py | 36 ++-- imagebind/models/image_bind.py | 36 ++-- minigpt4/common/registry.py | 4 +- .../configs/datasets/audioset/defaults.yaml | 5 + minigpt4/configs/datasets/bbc/defaults.yaml | 5 + .../configs/datasets/freesound/defaults.yaml | 5 + .../configs/datasets/soundbible/defaults.yaml | 5 + minigpt4/datasets/builders/__init__.py | 13 +- .../builders/audio_base_dataset_builder.py | 47 ++++- .../builders/audio_text_pair_builder.py | 56 ++++++ .../datasets/audio_caption/__init__.py | 1 + .../audio_caption/audio_caption_datasets.py | 9 +- .../datasets/image_caption/cc_sbu_dataset.py | 10 +- .../image_caption/image_caption_datasets.py | 4 +- .../datasets/image_caption/laion_dataset.py | 2 +- minigpt4/datasets/datasets/mixins/mixins.py | 2 +- minigpt4/models/bind_gpt4.py | 104 ++++++---- minigpt4/models/blip2.py | 2 +- minigpt4/models/mini_gpt4.py | 2 +- minigpt4/processors/__init__.py | 10 +- minigpt4/processors/audio_augment.py | 190 ++++++++++++++++++ minigpt4/processors/blip_processors.py | 2 +- .../processors/imagebind_audio_processor.py | 166 +++++++++++++++ ...essor.py => imagebind_vision_processor.py} | 4 +- .../{randaugment.py => vision_augment.py} | 0 requirements.txt | 4 +- .../bindgpt4_stage1_audio_pretrain.yaml | 68 +++++++ 31 files changed, 717 insertions(+), 267 deletions(-) create mode 100644 minigpt4/configs/datasets/audioset/defaults.yaml create mode 100644 minigpt4/configs/datasets/bbc/defaults.yaml create mode 100644 minigpt4/configs/datasets/freesound/defaults.yaml create mode 100644 minigpt4/configs/datasets/soundbible/defaults.yaml create mode 100644 minigpt4/datasets/builders/audio_text_pair_builder.py create mode 100644 minigpt4/processors/audio_augment.py create mode 100644 minigpt4/processors/imagebind_audio_processor.py rename minigpt4/processors/{imagebind_processor.py => imagebind_vision_processor.py} (98%) rename minigpt4/processors/{randaugment.py => vision_augment.py} (100%) create mode 100644 train_configs/bindgpt4_stage1_audio_pretrain.yaml diff --git a/README.md b/README.md index b1f8961..8979bf1 100644 --- a/README.md +++ b/README.md @@ -1,170 +1,5 @@ -# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models -[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution - -**King Abdullah University of Science and Technology** - - [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) - - -## News -We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB. 
- - -## Online Demo - -Click the image to chat with MiniGPT-4 around your images -[![demo](figs/online_demo.png)](https://minigpt-4.github.io) - - -## Examples - | | | -:-------------------------:|:-------------------------: -![find wild](figs/examples/wop_2.png) | ![write story](figs/examples/ad_2.png) -![solve problem](figs/examples/fix_1.png) | ![write Poem](figs/examples/rhyme_1.png) - -More examples can be found in the [project page](https://minigpt-4.github.io). - - - -## Introduction -- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer. -- We train MiniGPT-4 with two stages. The first traditional pretraining stage is trained using roughly 5 million aligned image-text pairs in 10 hours using 4 A100s. After the first stage, Vicuna is able to understand the image. But the generation ability of Vicuna is heavilly impacted. -- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset. -- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100. -- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4. - - -![overview](figs/overview.png) - - -## Getting Started -### Installation - -**1. Prepare the code and the environment** - -Git clone our repository, creating a python environment and ativate it via the following command - -```bash -git clone https://github.com/Vision-CAIR/MiniGPT-4.git -cd MiniGPT-4 -conda env create -f environment.yml -conda activate minigpt4 -``` - - -**2. Prepare the pretrained Vicuna weights** - -The current version of MiniGPT-4 is built on the v0 versoin of Vicuna-13B. -Please refer to our instruction [here](PrepareVicuna.md) -to prepare the Vicuna weights. -The final weights would be in a single folder in a structure similar to the following: - -``` -vicuna_weights -├── config.json -├── generation_config.json -├── pytorch_model.bin.index.json -├── pytorch_model-00001-of-00003.bin -... -``` - -Then, set the path to the vicuna weight in the model config file -[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16. - -**3. Prepare the pretrained MiniGPT-4 checkpoint** - -Download the pretrained checkpoints according to the Vicuna model you prepare. - -| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | -:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------: - [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) - - -Then, set the path to the pretrained checkpoint in the evaluation config file -in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11. - - - -### Launching Demo Locally - -Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running - -``` -python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0 -``` - -To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1. 
-This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B. -For more powerful GPUs, you can run the model -in 16 bit by setting low_resource to False in the config file -[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width. - -Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) - - -### Training -The training of MiniGPT-4 contains two alignment stages. - -**1. First pretraining stage** - -In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets -to align the vision and language model. To download and prepare the datasets, please check -our [first stage dataset preparation instruction](dataset/README_1_STAGE.md). -After the first stage, the visual features are mapped and can be understood by the language -model. -To launch the first stage training, run the following command. In our experiments, we use 4 A100. -You can change the save path in the config file -[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml) - -```bash -torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml -``` - -A MiniGPT-4 checkpoint with only stage one training can be downloaded -[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link). -Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently. - - -**2. Second finetuning stage** - -In the second stage, we use a small high quality image-text pair dataset created by ourselves -and convert it to a conversation format to further align MiniGPT-4. -To download and prepare our second stage dataset, please check our -[second stage dataset preparation instruction](dataset/README_2_STAGE.md). -To launch the second stage alignment, -first specify the path to the checkpoint file trained in stage 1 in -[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml). -You can also specify the output path there. -Then, run the following command. In our experiments, we use 1 A100. - -```bash -torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml -``` - -After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly. - - - - -## Acknowledgement - -+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check this great open-source work if you don't know it before! -+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis! -+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source! - - -If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX: -```bibtex -@article{zhu2023minigpt, - title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models}, - author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed}, - journal={arXiv preprint arXiv:2304.10592}, - year={2023} -} -``` - - -## License -This repository is under [BSD 3-Clause License](LICENSE.md). 
-Many codes are based on [Lavis](https://github.com/salesforce/LAVIS) with
-BSD 3-Clause License [here](LICENSE_Lavis.md).
+# TODO List
+- [ ] reconstruct the eval scripts
+- [ ] support audio evaluation
+- [ ] clean code repo & solve recursive import
+- [x] audio processor: random sampling
\ No newline at end of file
diff --git a/arnold_before.sh b/arnold_before.sh
index 844651a..6b019dc 100644
--- a/arnold_before.sh
+++ b/arnold_before.sh
@@ -6,7 +6,9 @@ pip3 install --upgrade pip
 pip3 install -r requirements.txt
 pip3 install byted-dataloader -i "https://bytedpypi.byted.org/simple"
-mmengine-0.7.3
+pip3 install mmengine==0.7.3
+pip3 install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
+
 # unset http_proxy && unset https_proxy && unset no_proxy
 #
 # ----------------------------------------------------------------------------------------
diff --git a/eval_scripts/eval_utils.py b/eval_scripts/eval_utils.py
index 45aec39..2f8b7a8 100644
--- a/eval_scripts/eval_utils.py
+++ b/eval_scripts/eval_utils.py
@@ -1,4 +1,5 @@
 import torch
+import torchaudio
 from PIL import Image
 
 
@@ -13,3 +14,12 @@ def load_image(image, image_processor):
     if len(image.shape) == 3:
         image = image.unsqueeze(0)
     return image
+
+
+def load_audio(audio, audio_processor):
+    if isinstance(audio, str):  # is an audio path
+        raw_audio = torchaudio.load(audio)
+        return audio_processor(raw_audio)
+    # elif isinstance(audio, )
+    else:
+        raise NotImplementedError
diff --git a/eval_scripts/qualitative_eval.py b/eval_scripts/qualitative_eval.py
index ebc5012..c10cbe0 100644
--- a/eval_scripts/qualitative_eval.py
+++ b/eval_scripts/qualitative_eval.py
@@ -125,6 +125,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column(scale=0.5):
             image = gr.Image(type="pil")
+            # audio = gr.Audio()
             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
             clear = gr.Button("Restart")
 
@@ -150,7 +151,7 @@ with gr.Blocks() as demo:
             chat_state = gr.State()
             emb_list = gr.State()
             chatbot = gr.Chatbot(label='BindGPT-4')
-            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
+            text_input = gr.Textbox(label='User', placeholder='Please upload your image/audio first', interactive=False)
 
     upload_button.click(upload_img, [image, text_input, chat_state],
                         [image, text_input, upload_button, chat_state, emb_list])
diff --git a/imagebind/data/data_utils.py b/imagebind/data/data_utils.py
index 49ec571..a3c6c79 100644
--- a/imagebind/data/data_utils.py
+++ b/imagebind/data/data_utils.py
@@ -15,7 +15,7 @@ import logging
 from imagebind.models.multimodal_preprocessors import SimpleTokenizer
 from PIL import Image
 from pytorchvideo import transforms as pv_transforms
-from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
 from pytorchvideo.data.encoded_video import EncodedVideo
 from torchvision import transforms
 
@@ -59,23 +59,11 @@ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
         fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
     elif p < 0:
         fbank = fbank[:, 0:target_length]
-    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
-    # channel image
+    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 channel image
     fbank = fbank.unsqueeze(0)
     return fbank
 
 
-def get_clip_timepoints(clip_sampler, duration):
-    # Read out all clips in this video
-    all_clips_timepoints = []
-
is_last_clip = False - end = 0.0 - while not is_last_clip: - start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None) - all_clips_timepoints.append((start, end)) - return all_clips_timepoints - - def load_and_transform_vision_data(image_paths, device): if image_paths is None: return None @@ -137,14 +125,14 @@ def load_and_transform_audio_data( waveform = torchaudio.functional.resample( waveform, orig_freq=sr, new_freq=sample_rate ) - all_clips_timepoints = get_clip_timepoints( + all_clips_timepoints = get_constant_clip_timepoints( clip_sampler, waveform.size(1) / sample_rate ) all_clips = [] for clip_timepoints in all_clips_timepoints: waveform_clip = waveform[ :, - int(clip_timepoints[0] * sample_rate) : int( + int(clip_timepoints[0] * sample_rate): int( clip_timepoints[1] * sample_rate ), ] @@ -162,7 +150,8 @@ def load_and_transform_audio_data( return torch.stack(audio_outputs, dim=0) -def get_clip_timepoints(clip_sampler, duration): +def get_constant_clip_timepoints(clip_sampler, duration): + assert isinstance(clip_sampler, ConstantClipsPerVideoSampler), "Incompatible Type of Sampler!" # Read out all clips in this video all_clips_timepoints = [] is_last_clip = False @@ -173,6 +162,13 @@ def get_clip_timepoints(clip_sampler, duration): return all_clips_timepoints +def get_random_clip_timepoints(clip_sampler, duration): + assert isinstance(clip_sampler, RandomMultiClipSampler), "Incompatible Type of Sampler!" + starts, ends, _, _, _ = clip_sampler(0.0, duration, annotation=None) + all_clips_timepoints = sorted(list(zip(starts, ends)), key=lambda x: x[0]) + return all_clips_timepoints + + def crop_boxes(boxes, x_offset, y_offset): """ Perform crop on the bounding boxes given the offsets. @@ -244,7 +240,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): x_offset = 0 elif spatial_idx == 2: x_offset = width - size - cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size] cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None if ndim == 3: cropped = cropped.squeeze(0) @@ -328,7 +324,7 @@ def load_and_transform_video_data( **{"sample_rate": sample_rate}, ) - all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) + all_clips_timepoints = get_constant_clip_timepoints(clip_sampler, video.duration) all_video = [] for clip_timepoints in all_clips_timepoints: @@ -347,4 +343,4 @@ def load_and_transform_video_data( all_video = torch.stack(all_video, dim=0) video_outputs.append(all_video) - return torch.stack(video_outputs, dim=0).to(device) \ No newline at end of file + return torch.stack(video_outputs, dim=0).to(device) diff --git a/imagebind/models/image_bind.py b/imagebind/models/image_bind.py index c93fcd0..487f479 100644 --- a/imagebind/models/image_bind.py +++ b/imagebind/models/image_bind.py @@ -50,46 +50,59 @@ ModalityType = SimpleNamespace( class ImageBindJoiner(nn.Module): def __init__(self, vision_query_token_num: int, + audio_query_token_num: int, vision_qformer_frozen: bool = False, vision_qformer_model: str = "", # The url or path of pre-trained vision Q-Former model vision_pre_dims: List[int] = (), # Projection before Q-Former - vision_post_dims: List[int] = (768, 768) # Projection after Q-Former + vision_post_dims: List[int] = (1280, 768), # Projection after Q-Former + audio_pre_dims: List[int] = (), + audio_post_dims: List[int] = (768, 768) ): super().__init__() assert not 
(vision_qformer_frozen and vision_qformer_model == "") - self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims) + self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims, audio_pre_dims) self.modality_qformers = self._create_modality_qformers(vision_query_token_num, vision_qformer_frozen, - vision_qformer_model) - self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims) + vision_qformer_model, + audio_query_token_num) + self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims, audio_post_dims) def _create_modality_pre_projectors(self, - vision_pre_dims + vision_pre_dims, + audio_pre_dims ): modality_pre_projectors = { - ModalityType.VISION: create_projectors(vision_pre_dims) + ModalityType.VISION: create_projectors(vision_pre_dims), + ModalityType.AUDIO: create_projectors(audio_pre_dims) } return modality_pre_projectors def _create_modality_qformers(self, vision_query_token_num, vision_qformer_frozen, - vision_qformer_model + vision_qformer_model, + audio_query_token_num ): vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num, freeze_qformer=vision_qformer_frozen, encoder_width=1280, # TODO: fix hard-coding q_former_model=vision_qformer_model) + audio_qformer = SequenceGenericQFormer(num_query_token=audio_query_token_num, + freeze_qformer=False, + encoder_width=768) modality_qformers = { - ModalityType.VISION: vision_qformer + ModalityType.VISION: vision_qformer, + ModalityType.AUDIO: audio_qformer } return nn.ModuleDict(modality_qformers) - def _create_modality_post_projectors(self, vision_post_dims): + def _create_modality_post_projectors(self, vision_post_dims, audio_post_dims): vision_projector = create_projectors(vision_post_dims) + audio_projector = create_projectors(audio_post_dims) modality_projectors = { - ModalityType.VISION: vision_projector + ModalityType.VISION: vision_projector, + ModalityType.AUDIO: audio_projector } return nn.ModuleDict(modality_projectors) @@ -97,7 +110,6 @@ class ImageBindJoiner(nn.Module): def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: outputs = {} for modality_key, modality_value in inputs.items(): - # assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported." if modality_value is not None: modality_value = self.modality_pre_projectors[modality_key](modality_value) modality_value = self.modality_qformers[modality_key](modality_value) @@ -544,7 +556,7 @@ class ImageBindModel(nn.Module): # NOTE: The reduction operation has been modified. 
if reduce_list: - modality_value = modality_value.reshape(B, S, *modality_value[2:]) + modality_value = modality_value.reshape(B, S, *modality_value.shape[1:]) modality_value = modality_value.mean(dim=1) outputs[modality_key] = modality_value diff --git a/minigpt4/common/registry.py b/minigpt4/common/registry.py index 75a6c23..b0a24cb 100644 --- a/minigpt4/common/registry.py +++ b/minigpt4/common/registry.py @@ -32,10 +32,12 @@ class Registry: """ def wrap(builder_cls): + # TODO: merge them or split builders by modality from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder + from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder assert issubclass( - builder_cls, ImageBaseDatasetBuilder + builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder) ), "All builders must inherit BaseDatasetBuilder class, found {}".format( builder_cls ) diff --git a/minigpt4/configs/datasets/audioset/defaults.yaml b/minigpt4/configs/datasets/audioset/defaults.yaml new file mode 100644 index 0000000..4ce44bf --- /dev/null +++ b/minigpt4/configs/datasets/audioset/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + audioset: + data_type: audio + build_info: + storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/AudioSet_SL/AudioSet_SL{00..54}.tar diff --git a/minigpt4/configs/datasets/bbc/defaults.yaml b/minigpt4/configs/datasets/bbc/defaults.yaml new file mode 100644 index 0000000..a0aa380 --- /dev/null +++ b/minigpt4/configs/datasets/bbc/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + bbc: + data_type: audio + build_info: + storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/BBC_Sound_Effects/BBC_Sound_Effects{000000..000062}.tar diff --git a/minigpt4/configs/datasets/freesound/defaults.yaml b/minigpt4/configs/datasets/freesound/defaults.yaml new file mode 100644 index 0000000..edc520f --- /dev/null +++ b/minigpt4/configs/datasets/freesound/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + freesound: + data_type: audio + build_info: + storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/FreeSound/FreeSound{000000..000524}.tar diff --git a/minigpt4/configs/datasets/soundbible/defaults.yaml b/minigpt4/configs/datasets/soundbible/defaults.yaml new file mode 100644 index 0000000..7b46a67 --- /dev/null +++ b/minigpt4/configs/datasets/soundbible/defaults.yaml @@ -0,0 +1,5 @@ +datasets: + soundbible: + data_type: audio + build_info: + storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/SoundBible/SoundBible0.tar diff --git a/minigpt4/datasets/builders/__init__.py b/minigpt4/datasets/builders/__init__.py index 744fe2f..7fd3f7d 100644 --- a/minigpt4/datasets/builders/__init__.py +++ b/minigpt4/datasets/builders/__init__.py @@ -11,12 +11,23 @@ from minigpt4.datasets.builders.image_text_pair_builder import ( LaionBuilderImage, CCSBUAlignBuilderImage ) +from minigpt4.datasets.builders.audio_text_pair_builder import ( + BBCBuilder, + AudioSetBuilder, + SoundBibleBuilder, + FreeSoundBuilder +) from minigpt4.common.registry import registry __all__ = [ "CCSBUBuilderImage", "LaionBuilderImage", - "CCSBUAlignBuilderImage" + "CCSBUAlignBuilderImage", + # Audio builders + "BBCBuilder", + "AudioSetBuilder", + "SoundBibleBuilder", + "FreeSoundBuilder", ] diff --git a/minigpt4/datasets/builders/audio_base_dataset_builder.py b/minigpt4/datasets/builders/audio_base_dataset_builder.py index 4e1361e..711fb7f 100644 --- a/minigpt4/datasets/builders/audio_base_dataset_builder.py +++ b/minigpt4/datasets/builders/audio_base_dataset_builder.py @@ -31,6 +31,51 
@@ class AudioBaseDatasetBuilder: self.data_type = self.config.data_type + self.audio_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + aud_proc_cfg = self.config.get("audio_processor") + txt_proc_cfg = self.config.get("text_processor") + + if aud_proc_cfg is not None: + aud_train_cfg = aud_proc_cfg.get("train") + aud_eval_cfg = aud_proc_cfg.get("eval") + + self.audio_processors["train"] = self._build_proc_from_cfg(aud_train_cfg) + self.audio_processors["eval"] = self._build_proc_from_cfg(aud_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + + self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) + self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + @classmethod def default_config_path(cls, type="default"): return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) @@ -95,5 +140,3 @@ class AudioBaseDatasetBuilder: filename = os.path.basename(storage_path) download_url(url=url_or_filename, root=dirname, filename=filename) - - diff --git a/minigpt4/datasets/builders/audio_text_pair_builder.py b/minigpt4/datasets/builders/audio_text_pair_builder.py new file mode 100644 index 0000000..d63bdcb --- /dev/null +++ b/minigpt4/datasets/builders/audio_text_pair_builder.py @@ -0,0 +1,56 @@ +import os +import logging +import warnings + +from minigpt4.common.registry import registry +from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder +from minigpt4.datasets.datasets.audio_caption import GenericAudioDataset + + + +class GenericAudioBuilder(AudioBaseDatasetBuilder): + train_dataset_cls = GenericAudioDataset + + def _download_ann(self): + pass + + def _download_aud(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" + + # create datasets + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + audio_processor=self.audio_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets + + +@registry.register_builder("bbc") +class BBCBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/bbc/defaults.yaml"} + + +@registry.register_builder("audioset") +class AudioSetBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults.yaml"} + + +@registry.register_builder("soundbible") +class SoundBibleBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/soundbible/defaults.yaml"} + + +@registry.register_builder("freesound") +class FreeSoundBuilder(GenericAudioBuilder): + DATASET_CONFIG_DICT = {"default": "configs/datasets/freesound/defaults.yaml"} diff --git 
a/minigpt4/datasets/datasets/audio_caption/__init__.py b/minigpt4/datasets/datasets/audio_caption/__init__.py index e69de29..c511eb3 100644 --- a/minigpt4/datasets/datasets/audio_caption/__init__.py +++ b/minigpt4/datasets/datasets/audio_caption/__init__.py @@ -0,0 +1 @@ +from minigpt4.datasets.datasets.audio_caption.audio_caption_datasets import GenericAudioDataset diff --git a/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py b/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py index 52b6dcd..57aa027 100644 --- a/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py +++ b/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py @@ -6,21 +6,22 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset class GenericAudioDataset(BaseDataset): - def __init__(self, vision_processor, text_processor, location): - super().__init__(x_processor=vision_processor, text_processor=text_processor) + def __init__(self, audio_processor, text_processor, location): + super().__init__(x_processor=audio_processor, text_processor=text_processor) self.inner_dataset = wds.DataPipeline( wds.ResampledShards(location), wds.tarfile_to_samples(handler=wds.warn_and_continue), wds.shuffle(1000, handler=wds.warn_and_continue), wds.decode(wds.torch_audio, handler=wds.warn_and_continue), - wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.to_tuple("flac", "json", handler=wds.warn_and_continue), wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), wds.map(self.to_dict, handler=wds.warn_and_continue), ) def to_dict(self, sample): return { - "image": sample[0], + "audio": sample[0], + # [clips_per_video, channel, mel_bins, time_steps] "text_input": self.text_processor(sample[1]["caption"]), } diff --git a/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py b/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py index eaaba91..9a4ffde 100644 --- a/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py +++ b/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py @@ -21,7 +21,7 @@ class CCSBUDataset(BaseDataset): def to_dict(self, sample): return { - "image": sample[0], + "vision": sample[0], "text_input": self.text_processor(sample[1]["caption"]), } @@ -41,7 +41,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset): caption = ann["caption"] return { - "image": image, + "vision": image, "text_input": caption, "image_id": self.img_ids[ann["image_id"]], } @@ -49,7 +49,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset): class CCDataset(BaseDataset): def __init__(self, vis_processor, text_processor, location): - super().__init__(vis_processor=vis_processor, text_processor=text_processor) + super().__init__(x_processor=vis_processor, text_processor=text_processor) self.inner_dataset = wds.DataPipeline( wds.ResampledShards(location), @@ -57,12 +57,12 @@ class CCDataset(BaseDataset): wds.shuffle(1000, handler=wds.warn_and_continue), wds.decode("pilrgb", handler=wds.warn_and_continue), wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue), - wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map_tuple(self.x_processor, handler=wds.warn_and_continue), wds.map(self.to_dict, handler=wds.warn_and_continue), ) def to_dict(self, sample): return { - "image": sample[0], + "vision": sample[0], "text_input": sample[1], } diff --git a/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py b/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py 
index 120fcf3..607d888 100644 --- a/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py +++ b/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py @@ -42,7 +42,7 @@ class ImageCaptionDataset(BaseDataset, __ImageDisplMixin): caption = self.text_processor(ann["caption"]) return { - "image": image, + "vision": image, "text_input": caption, "image_id": self.img_ids[ann["image_id"]], } @@ -67,7 +67,7 @@ class CaptionEvalDataset(BaseDataset, __ImageDisplMixin): image = self.x_processor(image) return { - "image": image, + "vision": image, "image_id": ann["image_id"], "instance_id": ann["instance_id"], } diff --git a/minigpt4/datasets/datasets/image_caption/laion_dataset.py b/minigpt4/datasets/datasets/image_caption/laion_dataset.py index 8cb34a8..e56160c 100644 --- a/minigpt4/datasets/datasets/image_caption/laion_dataset.py +++ b/minigpt4/datasets/datasets/image_caption/laion_dataset.py @@ -25,7 +25,7 @@ class LaionDataset(BaseDataset): def to_dict(self, sample): return { - "image": sample[0], + "vision": sample[0], "text_input": self.text_processor(sample[1]["caption"]), } diff --git a/minigpt4/datasets/datasets/mixins/mixins.py b/minigpt4/datasets/datasets/mixins/mixins.py index e143147..1c677fd 100644 --- a/minigpt4/datasets/datasets/mixins/mixins.py +++ b/minigpt4/datasets/datasets/mixins/mixins.py @@ -9,7 +9,7 @@ class __ImageDisplMixin: { "file": ann["image"], "caption": ann["caption"], - "image": sample["image"], + "vision": sample["vision"], } ) diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py index 59c60f6..bfb4c17 100644 --- a/minigpt4/models/bind_gpt4.py +++ b/minigpt4/models/bind_gpt4.py @@ -1,8 +1,9 @@ import random -from typing import Dict, Tuple +from typing import Dict, Tuple, List import torch import torch.nn as nn +import re from torch import Tensor from transformers import LlamaTokenizer @@ -12,6 +13,36 @@ from minigpt4.models.blip2 import BaseModel from minigpt4.models.modeling_llama import LlamaForCausalLM +def filter_prompt(input_embeds: Dict[str, Tensor], prompt_list: List[str]) -> List[str]: + if not prompt_list: + return prompt_list + input_modal_set = set([k.title() for k in input_embeds if input_embeds[k] is not None]) + prompt_modal_sets = [set(re.findall("<([^<>]+)>", prompt)) for prompt in prompt_list] + results = [prompt_list[i] for i, prompt_modal_set in enumerate(prompt_modal_sets) if + prompt_modal_set == input_modal_set] + return results + + +def arrange_modalities(input_embeds: Dict[str, Tensor], prompt: str) -> List[Tensor]: + prompt_modalities = re.findall("<([^<>]+)>", prompt) + return [input_embeds[modality.lower()] for modality in prompt_modalities] + + +def concat_all_embeddings(input_embeds: Dict[str, Tensor], dim: int) -> Tensor: + embeds = [input_embeds[key] for key in input_embeds if input_embeds[key] is not None] + return torch.cat(embeds, dim=dim) + + +def filter_modalities(inputs): + filtered_inputs = {} + + for k in ModalityType.__dict__.values(): + if k in inputs: + filtered_inputs[k] = inputs[k] + + return filtered_inputs + + @registry.register_model("bind_gpt4") class BindGPT4(BaseModel): """ @@ -61,7 +92,9 @@ class BindGPT4(BaseModel): print('Loading Q-Former and Adapter/Projector') self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token, vision_qformer_frozen=freeze_qformer, - vision_post_dims=[768, self.llama_model.config.hidden_size] + vision_post_dims=[768, self.llama_model.config.hidden_size], + audio_query_token_num=num_query_token, + audio_post_dims=[768, 
self.llama_model.config.hidden_size] # vision_qformer_model=q_former_model, # vision_pre_dims=(1280, 1408) ) @@ -84,42 +117,37 @@ class BindGPT4(BaseModel): def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: imagebind_outputs = self.multimodal_encoder(inputs) llama_inputs = self.multimodal_joiner(imagebind_outputs) - # NOTE: only accept image here return llama_inputs - def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]: - # TODO: Accept More Modalities. - input_embeds = inputs[modality_name] - attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device) - if prompt: - batch_size = input_embeds.shape[0] - p_before, p_after = prompt.split('') - p_before_tokens = self.llama_tokenizer( - p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device) - p_after_tokens = self.llama_tokenizer( - p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device) - p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) - p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) - wrapped_input_embeds = torch.cat([p_before_embeds, inputs, p_after_embeds], dim=1) - wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1]) - return wrapped_input_embeds, wrapped_atts_input - else: + def prompt_wrap(self, inputs: Dict[str, Tensor], prompt: str) -> Tuple[Tensor, Tensor]: + if not prompt: + input_embeds = concat_all_embeddings(inputs, dim=1) + attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device) return input_embeds, attns_input + input_embeds_list = arrange_modalities(inputs, prompt) + batch_size = input_embeds_list[0].shape[0] + prompt_slices = prompt.split('') + prompt_tokens = [self.llama_tokenizer(prompt_slice, return_tensors="pt", add_special_tokens=False) + .to(input_embeds_list[0].device) for prompt_slice in prompt_slices] + prompt_embeds = [self.llama_model.model.embed_tokens(prompt_token.input_ids).expand(batch_size, -1, -1) + for prompt_token in prompt_tokens] + result_embeds = [emb for pair in zip(prompt_embeds[:-1], input_embeds_list) + for emb in pair] + [prompt_embeds[-1]] + wrapped_input_embeds = torch.cat(result_embeds, dim=1) + wrapped_atts_input = torch.ones(wrapped_input_embeds.size()[:-1], + dtype=torch.long).to(wrapped_input_embeds.device) + return wrapped_input_embeds, wrapped_atts_input def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: - """ - TODO: More Modalities. - Only accept image inputs here. - Other modalities will conflict with the pre-defined prompt and wrapping strategy. - """ - bind_inputs = {ModalityType.VISION: inputs['image']} - embeds = self.encode_inputs(bind_inputs) - # assert "vision" in embeds, "Only Vision Input Can Be Accepted Now." - if self.prompt_list: - prompt = random.choice(self.prompt_list) + # filter `inputs` as it may contain informatioins other than modalities + modality_inputs = filter_modalities(inputs) + embeds = self.encode_inputs(modality_inputs) + filtered_prompts = filter_prompt(embeds, self.prompt_list) + if filtered_prompts: + prompt = random.choice(filtered_prompts) else: prompt = None - img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt) + input_embs, input_atts = self.prompt_wrap(embeds, prompt) # NOTE: No modifications from the next line to the end. Except for the autocast part. 
@@ -134,28 +162,28 @@ class BindGPT4(BaseModel): truncation=True, max_length=self.max_txt_len, add_special_tokens=False - ).to(img_embeds.device) + ).to(input_embs.device) targets = to_regress_tokens.input_ids.masked_fill( to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 ) empty_targets = ( - torch.ones([atts_img.shape[0], atts_img.shape[1] + 1], - dtype=torch.long).to(img_embeds.device).fill_(-100) # plus one for bos + torch.ones([input_atts.shape[0], input_atts.shape[1] + 1], + dtype=torch.long).to(input_embs.device).fill_(-100) # plus one for bos ) targets = torch.cat([empty_targets, targets], dim=1) - batch_size = img_embeds.shape[0] + batch_size = input_embs.shape[0] bos = torch.ones([batch_size, 1], dtype=to_regress_tokens.input_ids.dtype, device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id bos_embeds = self.llama_model.model.embed_tokens(bos) - atts_bos = atts_img[:, :1] + atts_bos = input_atts[:, :1] to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) - inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1) - attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) + inputs_embeds = torch.cat([bos_embeds, input_embs, to_regress_embeds], dim=1) + attention_mask = torch.cat([atts_bos, input_atts, to_regress_tokens.attention_mask], dim=1) outputs = self.llama_model( inputs_embeds=inputs_embeds, diff --git a/minigpt4/models/blip2.py b/minigpt4/models/blip2.py index e33bf82..46c1eea 100644 --- a/minigpt4/models/blip2.py +++ b/minigpt4/models/blip2.py @@ -144,7 +144,7 @@ def compute_sim_matrix(model, data_loader, **kwargs): vit_feats = [] image_embeds = [] for samples in data_loader: - image = samples["image"] + image = samples["vision"] image = image.to(model.device) image_feat, vit_feat = model.forward_image(image) diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py index 667edd5..2824ad1 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/mini_gpt4.py @@ -164,7 +164,7 @@ class MiniGPT4(Blip2Base): return img_embeds, atts_img def forward(self, samples): - image = samples["image"] + image = samples["vision"] img_embeds, atts_img = self.encode_img(image) if hasattr(samples, 'question_split'): # VQA dataset print('VQA Batch') diff --git a/minigpt4/processors/__init__.py b/minigpt4/processors/__init__.py index 0ce174e..57b6d9a 100644 --- a/minigpt4/processors/__init__.py +++ b/minigpt4/processors/__init__.py @@ -11,11 +11,15 @@ from minigpt4.processors.blip_processors import ( Blip2ImageEvalProcessor, BlipCaptionProcessor, ) -from minigpt4.processors.imagebind_processor import ( +from minigpt4.processors.imagebind_vision_processor import ( ImageBindCaptionProcessor, ImageBindVisionTrainProcessor, ImageBindVisionEvalProcessor ) +from minigpt4.processors.imagebind_audio_processor import ( + ImageBindAudioTrainProcessor, + ImageBindAudioEvalProcessor, +) from minigpt4.common.registry import registry @@ -26,7 +30,9 @@ __all__ = [ "BlipCaptionProcessor", "ImageBindCaptionProcessor", "ImageBindVisionTrainProcessor", - "ImageBindVisionEvalProcessor" + "ImageBindVisionEvalProcessor", + "ImageBindAudioTrainProcessor", + "ImageBindAudioEvalProcessor", ] diff --git a/minigpt4/processors/audio_augment.py b/minigpt4/processors/audio_augment.py new file mode 100644 index 0000000..869db2e --- /dev/null +++ b/minigpt4/processors/audio_augment.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# @Author : Xinhao Mei 
@CVSSP, University of Surrey +# @E-mail : x.mei@surrey.ac.uk + +""" + Implemenation of SpecAugment++, + Adapated from Qiuqiang Kong's trochlibrosa: + https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py +""" + +import torch +import torch.nn as nn + + +class DropStripes: + + def __init__(self, dim, drop_width, stripes_num): + """ Drop stripes. + args: + dim: int, dimension along which to drop + drop_width: int, maximum width of stripes to drop + stripes_num: int, how many stripes to drop + """ + super(DropStripes, self).__init__() + + assert dim in [2, 3] # dim 2: time; dim 3: frequency + + self.dim = dim + self.drop_width = drop_width + self.stripes_num = stripes_num + + def __call__(self, input): + """input: (batch_size, channels, time_steps, freq_bins)""" + + assert input.ndimension() == 4 + batch_size = input.shape[0] + total_width = input.shape[self.dim] + + for n in range(batch_size): + self.transform_slice(input[n], total_width) + + return input + + def transform_slice(self, e, total_width): + """ e: (channels, time_steps, freq_bins)""" + + for _ in range(self.stripes_num): + distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0] + bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0] + + if self.dim == 2: + e[:, bgn: bgn + distance, :] = 0 + elif self.dim == 3: + e[:, :, bgn: bgn + distance] = 0 + + +class MixStripes: + + def __init__(self, dim, mix_width, stripes_num): + """ Mix stripes + args: + dim: int, dimension along which to mix + mix_width: int, maximum width of stripes to mix + stripes_num: int, how many stripes to mix + """ + + super(MixStripes, self).__init__() + + assert dim in [2, 3] + + self.dim = dim + self.mix_width = mix_width + self.stripes_num = stripes_num + + def __call__(self, input): + """input: (batch_size, channel, time_steps, freq_bins)""" + + assert input.ndimension() == 4 + + batch_size = input.shape[0] + total_width = input.shape[self.dim] + + rand_sample = input[torch.randperm(batch_size)] + for i in range(batch_size): + self.transform_slice(input[i], rand_sample[i], total_width) + return input + + def transform_slice(self, input, random_sample, total_width): + + for _ in range(self.stripes_num): + distance = torch.randint(low=0, high=self.mix_width, size=(1,))[0] + bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0] + + if self.dim == 2: + input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \ + 0.5 * random_sample[:, bgn: bgn + distance, :] + elif self.dim == 3: + input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \ + 0.5 * random_sample[:, :, bgn: bgn + distance] + + +class CutStripes: + + def __init__(self, dim, cut_width, stripes_num): + """ Cutting stripes with another randomly selected sample in mini-batch. 
+ args: + dim: int, dimension along which to cut + cut_width: int, maximum width of stripes to cut + stripes_num: int, how many stripes to cut + """ + + super(CutStripes, self).__init__() + + assert dim in [2, 3] + + self.dim = dim + self.cut_width = cut_width + self.stripes_num = stripes_num + + def __call__(self, input): + """input: (batch_size, channel, time_steps, freq_bins)""" + + assert input.ndimension() == 4 + + batch_size = input.shape[0] + total_width = input.shape[self.dim] + + rand_sample = input[torch.randperm(batch_size)] + for i in range(batch_size): + self.transform_slice(input[i], rand_sample[i], total_width) + return input + + def transform_slice(self, input, random_sample, total_width): + + for _ in range(self.stripes_num): + distance = torch.randint(low=0, high=self.cut_width, size=(1,))[0] + bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0] + + if self.dim == 2: + input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :] + elif self.dim == 3: + input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance] + + +class SpecAugmentation: + + def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num, + mask_type='mixture'): + """Spec augmetation and SpecAugment++. + [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D. + and Le, Q.V., 2019. Specaugment: A simple data augmentation method + for automatic speech recognition. arXiv preprint arXiv:1904.08779. + [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space + Data Augmentation Method for Acoustic Scene Classification. arXiv + preprint arXiv:2103.16858. + + Args: + time_drop_width: int + time_stripes_num: int + freq_drop_width: int + freq_stripes_num: int + mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting) + """ + + super(SpecAugmentation, self).__init__() + + if mask_type == 'zero_value': + self.time_augmentator = DropStripes(dim=2, drop_width=time_drop_width, + stripes_num=time_stripes_num) + self.freq_augmentator = DropStripes(dim=3, drop_width=freq_drop_width, + stripes_num=freq_stripes_num) + elif mask_type == 'mixture': + self.time_augmentator = MixStripes(dim=2, mix_width=time_drop_width, + stripes_num=time_stripes_num) + self.freq_augmentator = MixStripes(dim=3, mix_width=freq_drop_width, + stripes_num=freq_stripes_num) + elif mask_type == 'cutting': + self.time_augmentator = CutStripes(dim=2, cut_width=time_drop_width, + stripes_num=time_stripes_num) + self.freq_augmentator = CutStripes(dim=3, cut_width=freq_drop_width, + stripes_num=freq_stripes_num) + else: + raise NameError('No such mask type in SpecAugment++') + + def __call__(self, inputs): + # x should be in size [batch_size, channel, time_steps, freq_bins] + x = self.time_augmentator(inputs) + x = self.freq_augmentator(x) + return x diff --git a/minigpt4/processors/blip_processors.py b/minigpt4/processors/blip_processors.py index fd26160..d2bc43f 100644 --- a/minigpt4/processors/blip_processors.py +++ b/minigpt4/processors/blip_processors.py @@ -9,7 +9,7 @@ import re from minigpt4.common.registry import registry from minigpt4.processors.base_processor import BaseProcessor -from minigpt4.processors.randaugment import RandomAugment +from minigpt4.processors.vision_augment import RandomAugment from omegaconf import OmegaConf from torchvision import transforms from torchvision.transforms.functional import InterpolationMode diff --git a/minigpt4/processors/imagebind_audio_processor.py 
b/minigpt4/processors/imagebind_audio_processor.py new file mode 100644 index 0000000..fa64cb3 --- /dev/null +++ b/minigpt4/processors/imagebind_audio_processor.py @@ -0,0 +1,166 @@ +from typing import Union, List + +import torch +import torchaudio +from omegaconf import OmegaConf +from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler +from torch import Tensor + +from imagebind.data.data_utils import waveform2melspec, get_constant_clip_timepoints, \ + get_random_clip_timepoints +from minigpt4.processors.base_processor import BaseProcessor +from torchvision import transforms +from minigpt4.common.registry import registry +from minigpt4.processors.audio_augment import SpecAugmentation + + +class ImageBindAudioBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None, + num_mel_bins=None, target_length=None, clip_sample_method="Random"): + super().__init__() + self.mean = -4.268 if mean is None else mean + self.std = 9.138 if std is None else std + self.target_sr = 16000 if target_sr is None else target_sr + self.num_mel_bins = num_mel_bins + self.target_length = target_length + self.clip_sampler = self._construct_clip_sampler(clip_duration, clips_per_video, clip_sample_method) + self.normalize = transforms.Normalize(self.mean, self.std) + + def _construct_clip_sampler(self, clip_duration, clips_per_video, clip_sample_method): + if clip_duration is None or clips_per_video is None: + return None + if clip_sample_method == "Constant": + return ConstantClipsPerVideoSampler( + clip_duration=clip_duration, clips_per_video=clips_per_video + ) + elif clip_sample_method == "Random": + return RandomMultiClipSampler(clip_duration=clip_duration, num_clips=clips_per_video) + else: + raise NotImplementedError + + def waveform_resample(self, waveform: Tensor, origin_sr: int) -> Tensor: + return torchaudio.functional.resample( + waveform, orig_freq=origin_sr, new_freq=self.target_sr + ) + + def clip_sample(self, waveform: Tensor) -> List[Tensor]: + if self.clip_sampler is None: + return [waveform] + elif isinstance(self.clip_sampler, ConstantClipsPerVideoSampler): + all_clips_timepoints = get_constant_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr) + elif isinstance(self.clip_sampler, RandomMultiClipSampler): + all_clips_timepoints = get_random_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr) + else: + raise NotImplementedError + all_clips = [] + for clip_timepoints in all_clips_timepoints: + start_pos = int(clip_timepoints[0] * self.target_sr) + end_pos = int(clip_timepoints[1] * self.target_sr) + waveform_clip = waveform[:, start_pos: end_pos] + all_clips.append(waveform_clip) + return all_clips + + def waveform_melspec(self, waveforms: Union[List[Tensor], Tensor]) -> List[Tensor]: + if isinstance(waveforms, Tensor): + return waveform2melspec(waveforms, self.target_sr, self.num_mel_bins, self.target_length) + else: + return [waveform2melspec(waveform, self.target_sr, self.num_mel_bins, self.target_length) + for waveform in waveforms] + + +@registry.register_processor("imagebind_audio_train") +class ImageBindAudioTrainProcessor(ImageBindAudioBaseProcessor): + def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None, + clip_sample_method="Random", num_mel_bins=None, target_length=None, time_drop_width=13, + time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2, mask_type='mixture'): + 
super().__init__(mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, target_length=target_length, + clip_sample_method=clip_sample_method) + self.spec_augment = SpecAugmentation(time_drop_width, time_stripes_num, + freq_drop_width, freq_stripes_num, mask_type) + + def __call__(self, item): + # item: Tuple[Tensor, int] + waveform, origin_sr = item[0], item[1] + waveform = self.waveform_resample(waveform, origin_sr) + waveform_clips = self.clip_sample(waveform) + melspec_clips = self.waveform_melspec(waveform_clips) + normed_melspecs = [self.normalize(clip) for clip in melspec_clips] + all_clips = torch.stack(normed_melspecs, dim=0) + # all_clips: [clips_per_video, channel, mel_bins, time_steps] + # augment: [batch_size, channel, time_steps, freq_bins] + augmented_clips = self.spec_augment(all_clips.transpose(-2, -1)).transpose(-2, -1) + return augmented_clips + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + target_sr = cfg.get("target_sr", 16000) + clip_duration = cfg.get("clip_duration", 2) + clips_per_video = cfg.get("clips_per_video", 3) + num_mel_bins = cfg.get("num_mel_bins", 128) + target_length = cfg.get("target_length", 204) + time_drop_width = cfg.get("time_drop_width", 13) + time_stripes_num = cfg.get("time_stripes_num", 2) + # 13 * 2 / 204 = 12.75% Time Mask + freq_drop_width = cfg.get("freq_drop_width", 8) + freq_stripes_num = cfg.get("freq_stripes_num", 2) + # 8 * 2 / 128 = 12.5% Freq Mask + mask_type = cfg.get("mask_type", 'mixture') + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls( + mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, target_length=target_length, + time_drop_width=time_drop_width, time_stripes_num=time_stripes_num, + freq_drop_width=freq_drop_width, freq_stripes_num=freq_stripes_num, + mask_type=mask_type + ) + + +@registry.register_processor("imagebind_audio_eval") +class ImageBindAudioEvalProcessor(ImageBindAudioBaseProcessor): + def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None, + clip_sample_method="Constant", num_mel_bins=None, target_length=None): + super().__init__(mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, target_length=target_length, + clip_sample_method=clip_sample_method) + + def __call__(self, item): + # item: Tuple[Tensor, int] + waveform, origin_sr = item[0], item[1] + waveform = self.waveform_resample(waveform, origin_sr) + waveform_clips = self.clip_sample(waveform) + melspec_clips = self.waveform_melspec(waveform_clips) + normed_melspecs = [self.normalize(clip) for clip in melspec_clips] + all_clips = torch.stack(normed_melspecs, dim=0) + # all_clips: [clips_per_video, channel, mel_bins, time_steps] + return all_clips + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + target_sr = cfg.get("target_sr", 16000) + clip_duration = cfg.get("clip_duration", 2) + clips_per_video = cfg.get("clips_per_video", 3) + num_mel_bins = cfg.get("num_mel_bins", 128) + target_length = cfg.get("target_length", 204) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls( + mean=mean, std=std, target_sr=target_sr, + clip_duration=clip_duration, clips_per_video=clips_per_video, + num_mel_bins=num_mel_bins, 
target_length=target_length + ) + diff --git a/minigpt4/processors/imagebind_processor.py b/minigpt4/processors/imagebind_vision_processor.py similarity index 98% rename from minigpt4/processors/imagebind_processor.py rename to minigpt4/processors/imagebind_vision_processor.py index 837f499..67b92f1 100644 --- a/minigpt4/processors/imagebind_processor.py +++ b/minigpt4/processors/imagebind_vision_processor.py @@ -9,7 +9,7 @@ import re from minigpt4.common.registry import registry from minigpt4.processors.base_processor import BaseProcessor -from minigpt4.processors.randaugment import RandomAugment +from minigpt4.processors.vision_augment import RandomAugment from omegaconf import OmegaConf from torchvision import transforms from torchvision.transforms.functional import InterpolationMode @@ -147,3 +147,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor): return cls(image_size=image_size, mean=mean, std=std) + + diff --git a/minigpt4/processors/randaugment.py b/minigpt4/processors/vision_augment.py similarity index 100% rename from minigpt4/processors/randaugment.py rename to minigpt4/processors/vision_augment.py diff --git a/requirements.txt b/requirements.txt index 3125924..a908368 100644 --- a/requirements.txt +++ b/requirements.txt @@ -53,7 +53,6 @@ ipdb tensorflow-cpu tensorboardX # mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html - bs4==0.0.1 # Needed for text cleaning bson==0.5.10 byted-dataloader==0.3.6 @@ -65,4 +64,5 @@ sentencepiece==0.1.99 # Needed for T5 tokenizer tensorboard==2.11.2 tensorflow==2.11.0 # Needed for tensorboard hdfs support tensorflow-io==0.30.0 # Needed for tensorboard hdfs support -tqdm==4.64.1 \ No newline at end of file +tqdm==4.64.1 +pytorchvideo==0.1.5 \ No newline at end of file diff --git a/train_configs/bindgpt4_stage1_audio_pretrain.yaml b/train_configs/bindgpt4_stage1_audio_pretrain.yaml new file mode 100644 index 0000000..9a2c023 --- /dev/null +++ b/train_configs/bindgpt4_stage1_audio_pretrain.yaml @@ -0,0 +1,68 @@ +model: + arch: bind_gpt4 + model_type: pretrain_vicuna + freeze_imagebind: True + freeze_qformer: False + +datasets: + bbc: + audio_processor: + train: + name: "imagebind_audio_train" + text_processor: + train: + name: "imagebind_caption" + sample_ratio: 31 + audioset: + audio_processor: + train: + name: "imagebind_audio_train" + text_processor: + train: + name: "imagebind_caption" + sample_ratio: 108 + soundbible: + audio_processor: + train: + name: "imagebind_audio_train" + text_processor: + train: + name: "imagebind_caption" + sample_ratio: 2 + freesound: + audio_processor: + train: + name: "imagebind_audio_train" + text_processor: + train: + name: "imagebind_caption" + sample_ratio: 262 + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 4 + batch_size_train: 64 + batch_size_eval: 64 + num_workers: 4 + warmup_steps: 5000 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "output/bindgpt4_stage1_audio_pretrain" + + amp: True + resume_ckpt_path: null + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file
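
The split of `get_clip_timepoints` into `get_constant_clip_timepoints` and `get_random_clip_timepoints` in `imagebind/data/data_utils.py` is what backs the checked-off "audio processor: random sampling" TODO. A minimal sketch of how the two samplers differ, assuming the pinned `pytorchvideo` exposes `RandomMultiClipSampler` with the constructor and 5-tuple return that the new code relies on:

```python
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler

duration = 7.3  # seconds of audio in the file

# Constant sampling: walk the timeline until is_last_clip, as get_constant_clip_timepoints does.
constant = ConstantClipsPerVideoSampler(clip_duration=2, clips_per_video=3)
timepoints, end, is_last = [], 0.0, False
while not is_last:
    start, end, _, _, is_last = constant(end, duration, annotation=None)
    timepoints.append((start, end))
print(timepoints)  # three evenly spaced 2 s windows

# Random sampling: one call returns all clip boundaries, as get_random_clip_timepoints does.
random_sampler = RandomMultiClipSampler(clip_duration=2, num_clips=3)
starts, ends, _, _, _ = random_sampler(0.0, duration, annotation=None)
print(sorted(zip(starts, ends)))  # three random (possibly overlapping) 2 s windows, sorted by start
```

Training uses the random sampler (the `clip_sample_method="Random"` default in `ImageBindAudioTrainProcessor`), while evaluation keeps the deterministic constant sampler.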
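
Both audio processors share the same per-clip path: resample to 16 kHz, cut clips, convert each clip with `waveform2melspec`, then normalize with the `mean=-4.268`, `std=9.138` defaults. A rough, self-contained sketch of that path for one clip; the kaldi fbank settings here are an assumption meant to mirror ImageBind's `waveform2melspec`, not a verbatim copy of it:

```python
import torch
import torchaudio


def clip_to_melspec(clip: torch.Tensor, sr: int = 16000,
                    num_mel_bins: int = 128, target_length: int = 204,
                    mean: float = -4.268, std: float = 9.138) -> torch.Tensor:
    """One waveform clip [channels, samples] -> normalized log-mel [1, mel_bins, target_length]."""
    clip = clip - clip.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        clip, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type="hanning", num_mel_bins=num_mel_bins, dither=0.0, frame_shift=10,
    )                                           # [num_frames, num_mel_bins]
    fbank = fbank.transpose(0, 1).unsqueeze(0)  # [1, mel_bins, num_frames]
    pad = target_length - fbank.shape[-1]
    fbank = torch.nn.functional.pad(fbank, (0, pad)) if pad > 0 else fbank[..., :target_length]
    return (fbank - mean) / std


if __name__ == "__main__":
    spec = clip_to_melspec(torch.randn(1, 2 * 16000))   # a single 2-second clip at 16 kHz
    print(spec.shape)                                   # torch.Size([1, 128, 204])
    sample = torch.stack([clip_to_melspec(torch.randn(1, 2 * 16000)) for _ in range(3)])
    print(sample.shape)                                 # [3, 1, 128, 204]
```

With the training defaults (2 s clips, 3 clips per file, 128 mel bins, 204 frames), each sample comes out as a `[clips_per_video, channel, mel_bins, time_steps]` tensor, matching the shape documented in `GenericAudioDataset.to_dict`.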
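
`ImageBindAudioTrainProcessor` additionally runs SpecAugment++ from the new `audio_augment.py`, transposing the spectrogram so the augmenter sees `[..., time_steps, freq_bins]`. A short usage sketch with the `from_config` defaults, assuming the repo is on `PYTHONPATH`:

```python
import torch

from minigpt4.processors.audio_augment import SpecAugmentation

# Defaults from ImageBindAudioTrainProcessor.from_config:
# 2 stripes of up to 13 frames (~12.75% of 204) on the time axis,
# 2 stripes of up to 8 bins (12.5% of 128) on the frequency axis, mixed with another clip in the batch.
spec_augment = SpecAugmentation(time_drop_width=13, time_stripes_num=2,
                                freq_drop_width=8, freq_stripes_num=2, mask_type="mixture")

clips = torch.randn(3, 1, 128, 204)                 # [clips_per_video, channel, mel_bins, time_steps]
augmented = spec_augment(clips.transpose(-2, -1))   # augmenter expects [..., time_steps, freq_bins]
clips = augmented.transpose(-2, -1)                 # back to [clips, channel, mel_bins, time_steps]
print(clips.shape)
```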
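
On the model side, `BindGPT4.forward` no longer assumes a vision-only batch: `filter_prompt` keeps only prompt templates whose modality tags match the modalities actually present, and `prompt_wrap` interleaves the prompt slices with the corresponding embeddings. A toy reproduction of the tag-matching step, assuming templates mark modalities with tags such as `<Vision>` and `<Audio>` (the placeholder token that `prompt_wrap` splits on is not visible in this patch and is left out here):

```python
import re
from typing import Dict, List, Optional

import torch
from torch import Tensor


def filter_prompt(input_embeds: Dict[str, Optional[Tensor]], prompt_list: List[str]) -> List[str]:
    """Keep prompts whose <Tag> set equals the set of modalities present in the batch."""
    input_modal_set = {k.title() for k, v in input_embeds.items() if v is not None}
    return [p for p in prompt_list if set(re.findall("<([^<>]+)>", p)) == input_modal_set]


if __name__ == "__main__":
    embeds = {"audio": torch.zeros(2, 8, 4096), "vision": None}  # toy post-projection embeddings
    prompts = [
        "###Human: <Vision> Describe this image. ###Assistant:",
        "###Human: <Audio> Describe this sound. ###Assistant:",
        "###Human: <Vision><Audio> Describe what you see and hear. ###Assistant:",
    ]
    print(filter_prompt(embeds, prompts))  # only the <Audio>-only template survives
```

Audio-only batches therefore never get paired with a vision prompt, and vice versa.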