diff --git a/README.md b/README.md
index b1f8961..8979bf1 100644
--- a/README.md
+++ b/README.md
@@ -1,170 +1,5 @@
-# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
-[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
-
-**King Abdullah University of Science and Technology**
-
-<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>  <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
-
-
-## News
-We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB.
-
-
-## Online Demo
-
-Click the image to chat with MiniGPT-4 around your images
-[![demo](figs/online_demo.png)](https://minigpt-4.github.io)
-
-
-## Examples
-  |   |   |
-:-------------------------:|:-------------------------:
-![find wild](figs/examples/wop_2.png) |  ![write story](figs/examples/ad_2.png)
-![solve problem](figs/examples/fix_1.png)  |  ![write Poem](figs/examples/rhyme_1.png)
-
-More examples can be found in the [project page](https://minigpt-4.github.io).
-
-
-
-## Introduction
-- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer. 
-- We train MiniGPT-4 with two stages. The first traditional pretraining stage is trained using roughly 5 million aligned image-text pairs in 10 hours using 4 A100s. After the first stage, Vicuna is able to understand the image. But the generation ability of Vicuna is heavilly impacted.
-- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset.
-- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100.
-- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4. 
-
-
-![overview](figs/overview.png)
-
-
-## Getting Started
-### Installation
-
-**1. Prepare the code and the environment**
-
-Git clone our repository, creating a python environment and ativate it via the following command
-
-```bash
-git clone https://github.com/Vision-CAIR/MiniGPT-4.git
-cd MiniGPT-4
-conda env create -f environment.yml
-conda activate minigpt4
-```
-
-
-**2. Prepare the pretrained Vicuna weights**
-
-The current version of MiniGPT-4 is built on the v0 versoin of Vicuna-13B.
-Please refer to our instruction [here](PrepareVicuna.md) 
-to prepare the Vicuna weights.
-The final weights would be in a single folder in a structure similar to the following:
-
-```
-vicuna_weights
-├── config.json
-├── generation_config.json
-├── pytorch_model.bin.index.json
-├── pytorch_model-00001-of-00003.bin
-...   
-```
-
-Then, set the path to the vicuna weight in the model config file 
-[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
-
-**3. Prepare the pretrained MiniGPT-4 checkpoint**
-
-Download the pretrained checkpoints according to the Vicuna model you prepare.
-
-|                                Checkpoint Aligned with Vicuna 13B                                |                               Checkpoint Aligned with Vicuna 7B                                |
-:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
- [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) 
-
-
-Then, set the path to the pretrained checkpoint in the evaluation config file 
-in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11. 
-
-
-
-### Launching Demo Locally
-
-Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running
-
-```
-python demo.py --cfg-path eval_configs/minigpt4_eval.yaml  --gpu-id 0
-```
-
-To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1. 
-This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B. 
-For more powerful GPUs, you can run the model
-in 16 bit by setting low_resource to False in the config file 
-[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
-
-Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
-
-
-### Training
-The training of MiniGPT-4 contains two alignment stages.
-
-**1. First pretraining stage**
-
-In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets
-to align the vision and language model. To download and prepare the datasets, please check 
-our [first stage dataset preparation instruction](dataset/README_1_STAGE.md). 
-After the first stage, the visual features are mapped and can be understood by the language
-model.
-To launch the first stage training, run the following command. In our experiments, we use 4 A100. 
-You can change the save path in the config file 
-[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
-
-```bash
-torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
-```
-
-A MiniGPT-4 checkpoint with only stage one training can be downloaded 
-[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
-Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently.
-
-
-**2. Second finetuning stage**
-
-In the second stage, we use a small high quality image-text pair dataset created by ourselves
-and convert it to a conversation format to further align MiniGPT-4.
-To download and prepare our second stage dataset, please check our 
-[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
-To launch the second stage alignment, 
-first specify the path to the checkpoint file trained in stage 1 in 
-[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml).
-You can also specify the output path there. 
-Then, run the following command. In our experiments, we use 1 A100.
-
-```bash
-torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
-```
-
-After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly. 
-
-
-
-
-## Acknowledgement
-
-+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check this great open-source work if you don't know it before!
-+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
-+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
-
-
-If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX:
-```bibtex
-@article{zhu2023minigpt,
-  title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
-  author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
-  journal={arXiv preprint arXiv:2304.10592},
-  year={2023}
-}
-```
-
-
-## License
-This repository is under [BSD 3-Clause License](LICENSE.md).
-Many codes are based on [Lavis](https://github.com/salesforce/LAVIS) with 
-BSD 3-Clause License [here](LICENSE_Lavis.md).
+# TODO List
+- [ ] reconstruct the eval scripts
+- [ ] support audio evaluation
+- [ ] clean code repo & solve recursive import
+- [x] audio processor: random sampling
\ No newline at end of file
diff --git a/arnold_before.sh b/arnold_before.sh
index 844651a..6b019dc 100644
--- a/arnold_before.sh
+++ b/arnold_before.sh
@@ -6,7 +6,9 @@
 pip3 install --upgrade pip
 pip3 install -r requirements.txt
 pip3 install byted-dataloader -i "https://bytedpypi.byted.org/simple"
-mmengine-0.7.3
+pip3 install mmmengine==0.7.3
+pip3 install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
+
 # unset http_proxy && unset https_proxy && unset no_proxy
 
 # # ----------------------------------------------------------------------------------------
diff --git a/eval_scripts/eval_utils.py b/eval_scripts/eval_utils.py
index 45aec39..2f8b7a8 100644
--- a/eval_scripts/eval_utils.py
+++ b/eval_scripts/eval_utils.py
@@ -1,4 +1,5 @@
 import torch
+import torchaudio
 from PIL import Image
 
 
@@ -13,3 +14,12 @@ def load_image(image, image_processor):
         if len(image.shape) == 3:
             image = image.unsqueeze(0)
     return image
+
+
+def load_audio(audio, audio_processor):
+    if isinstance(audio, str):  # is a audio path
+        raw_audio = torchaudio.load(audio)
+        audio = audio_processor(audio)
+    # elif isinstance(audio, )
+    else:
+        raise NotImplementedError
diff --git a/eval_scripts/qualitative_eval.py b/eval_scripts/qualitative_eval.py
index ebc5012..c10cbe0 100644
--- a/eval_scripts/qualitative_eval.py
+++ b/eval_scripts/qualitative_eval.py
@@ -125,6 +125,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column(scale=0.5):
             image = gr.Image(type="pil")
+            # audio = gr.Audio()
             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
             clear = gr.Button("Restart")
 
@@ -150,7 +151,7 @@ with gr.Blocks() as demo:
             chat_state = gr.State()
             emb_list = gr.State()
             chatbot = gr.Chatbot(label='BindGPT-4')
-            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
+            text_input = gr.Textbox(label='User', placeholder='Please upload your image/audio first', interactive=False)
 
     upload_button.click(upload_img, [image, text_input, chat_state],
                         [image, text_input, upload_button, chat_state, emb_list])
diff --git a/imagebind/data/data_utils.py b/imagebind/data/data_utils.py
index 49ec571..a3c6c79 100644
--- a/imagebind/data/data_utils.py
+++ b/imagebind/data/data_utils.py
@@ -15,7 +15,7 @@ import logging
 from imagebind.models.multimodal_preprocessors import SimpleTokenizer
 from PIL import Image
 from pytorchvideo import transforms as pv_transforms
-from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
 from pytorchvideo.data.encoded_video import EncodedVideo
 
 from torchvision import transforms
@@ -59,23 +59,11 @@ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
         fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
     elif p < 0:
         fbank = fbank[:, 0:target_length]
-    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
-    # channel image
+    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 channel image
     fbank = fbank.unsqueeze(0)
     return fbank
 
 
-def get_clip_timepoints(clip_sampler, duration):
-    # Read out all clips in this video
-    all_clips_timepoints = []
-    is_last_clip = False
-    end = 0.0
-    while not is_last_clip:
-        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
-        all_clips_timepoints.append((start, end))
-    return all_clips_timepoints
-
-
 def load_and_transform_vision_data(image_paths, device):
     if image_paths is None:
         return None
@@ -137,14 +125,14 @@ def load_and_transform_audio_data(
             waveform = torchaudio.functional.resample(
                 waveform, orig_freq=sr, new_freq=sample_rate
             )
-        all_clips_timepoints = get_clip_timepoints(
+        all_clips_timepoints = get_constant_clip_timepoints(
             clip_sampler, waveform.size(1) / sample_rate
         )
         all_clips = []
         for clip_timepoints in all_clips_timepoints:
             waveform_clip = waveform[
                 :,
-                int(clip_timepoints[0] * sample_rate) : int(
+                int(clip_timepoints[0] * sample_rate): int(
                     clip_timepoints[1] * sample_rate
                 ),
             ]
@@ -162,7 +150,8 @@ def load_and_transform_audio_data(
     return torch.stack(audio_outputs, dim=0)
 
 
-def get_clip_timepoints(clip_sampler, duration):
+def get_constant_clip_timepoints(clip_sampler, duration):
+    assert isinstance(clip_sampler, ConstantClipsPerVideoSampler), "Incompatible Type of Sampler!"
     # Read out all clips in this video
     all_clips_timepoints = []
     is_last_clip = False
@@ -173,6 +162,13 @@ def get_clip_timepoints(clip_sampler, duration):
     return all_clips_timepoints
 
 
+def get_random_clip_timepoints(clip_sampler, duration):
+    assert isinstance(clip_sampler, RandomMultiClipSampler), "Incompatible Type of Sampler!"
+    starts, ends, _, _, _ = clip_sampler(0.0, duration, annotation=None)
+    all_clips_timepoints = sorted(list(zip(starts, ends)), key=lambda x: x[0])
+    return all_clips_timepoints
+
+
 def crop_boxes(boxes, x_offset, y_offset):
     """
     Perform crop on the bounding boxes given the offsets.
@@ -244,7 +240,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
             x_offset = 0
         elif spatial_idx == 2:
             x_offset = width - size
-    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+    cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
     cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
     if ndim == 3:
         cropped = cropped.squeeze(0)
@@ -328,7 +324,7 @@ def load_and_transform_video_data(
             **{"sample_rate": sample_rate},
         )
 
-        all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
+        all_clips_timepoints = get_constant_clip_timepoints(clip_sampler, video.duration)
 
         all_video = []
         for clip_timepoints in all_clips_timepoints:
@@ -347,4 +343,4 @@ def load_and_transform_video_data(
         all_video = torch.stack(all_video, dim=0)
         video_outputs.append(all_video)
 
-    return torch.stack(video_outputs, dim=0).to(device)
\ No newline at end of file
+    return torch.stack(video_outputs, dim=0).to(device)
diff --git a/imagebind/models/image_bind.py b/imagebind/models/image_bind.py
index c93fcd0..487f479 100644
--- a/imagebind/models/image_bind.py
+++ b/imagebind/models/image_bind.py
@@ -50,46 +50,59 @@ ModalityType = SimpleNamespace(
 class ImageBindJoiner(nn.Module):
     def __init__(self,
                  vision_query_token_num: int,
+                 audio_query_token_num: int,
                  vision_qformer_frozen: bool = False,
                  vision_qformer_model: str = "",  # The url or path of pre-trained vision Q-Former model
                  vision_pre_dims: List[int] = (),  # Projection before Q-Former
-                 vision_post_dims: List[int] = (768, 768)  # Projection after Q-Former
+                 vision_post_dims: List[int] = (1280, 768),  # Projection after Q-Former
+                 audio_pre_dims: List[int] = (),
+                 audio_post_dims: List[int] = (768, 768)
                  ):
         super().__init__()
         assert not (vision_qformer_frozen and vision_qformer_model == "")
-        self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims)
+        self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims, audio_pre_dims)
         self.modality_qformers = self._create_modality_qformers(vision_query_token_num,
                                                                 vision_qformer_frozen,
-                                                                vision_qformer_model)
-        self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims)
+                                                                vision_qformer_model,
+                                                                audio_query_token_num)
+        self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims, audio_post_dims)
 
     def _create_modality_pre_projectors(self,
-                                        vision_pre_dims
+                                        vision_pre_dims,
+                                        audio_pre_dims
                                         ):
         modality_pre_projectors = {
-            ModalityType.VISION: create_projectors(vision_pre_dims)
+            ModalityType.VISION: create_projectors(vision_pre_dims),
+            ModalityType.AUDIO: create_projectors(audio_pre_dims)
         }
         return modality_pre_projectors
 
     def _create_modality_qformers(self,
                                   vision_query_token_num,
                                   vision_qformer_frozen,
-                                  vision_qformer_model
+                                  vision_qformer_model,
+                                  audio_query_token_num
                                   ):
         vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
                                                 freeze_qformer=vision_qformer_frozen,
                                                 encoder_width=1280,  # TODO: fix hard-coding
                                                 q_former_model=vision_qformer_model)
+        audio_qformer = SequenceGenericQFormer(num_query_token=audio_query_token_num,
+                                               freeze_qformer=False,
+                                               encoder_width=768)
         modality_qformers = {
-            ModalityType.VISION: vision_qformer
+            ModalityType.VISION: vision_qformer,
+            ModalityType.AUDIO: audio_qformer
         }
 
         return nn.ModuleDict(modality_qformers)
 
-    def _create_modality_post_projectors(self, vision_post_dims):
+    def _create_modality_post_projectors(self, vision_post_dims, audio_post_dims):
         vision_projector = create_projectors(vision_post_dims)
+        audio_projector = create_projectors(audio_post_dims)
         modality_projectors = {
-            ModalityType.VISION: vision_projector
+            ModalityType.VISION: vision_projector,
+            ModalityType.AUDIO: audio_projector
         }
 
         return nn.ModuleDict(modality_projectors)
@@ -97,7 +110,6 @@ class ImageBindJoiner(nn.Module):
     def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
         outputs = {}
         for modality_key, modality_value in inputs.items():
-            # assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported."
             if modality_value is not None:
                 modality_value = self.modality_pre_projectors[modality_key](modality_value)
                 modality_value = self.modality_qformers[modality_key](modality_value)
@@ -544,7 +556,7 @@ class ImageBindModel(nn.Module):
 
                 # NOTE: The reduction operation has been modified.
                 if reduce_list:
-                    modality_value = modality_value.reshape(B, S, *modality_value[2:])
+                    modality_value = modality_value.reshape(B, S, *modality_value.shape[1:])
                     modality_value = modality_value.mean(dim=1)
 
                 outputs[modality_key] = modality_value
diff --git a/minigpt4/common/registry.py b/minigpt4/common/registry.py
index 75a6c23..b0a24cb 100644
--- a/minigpt4/common/registry.py
+++ b/minigpt4/common/registry.py
@@ -32,10 +32,12 @@ class Registry:
         """
 
         def wrap(builder_cls):
+            # TODO: merge them or split builders by modality
             from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
+            from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
 
             assert issubclass(
-                builder_cls, ImageBaseDatasetBuilder
+                builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder)
             ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                 builder_cls
             )
diff --git a/minigpt4/configs/datasets/audioset/defaults.yaml b/minigpt4/configs/datasets/audioset/defaults.yaml
new file mode 100644
index 0000000..4ce44bf
--- /dev/null
+++ b/minigpt4/configs/datasets/audioset/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+  audioset:
+    data_type: audio
+    build_info:
+      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/AudioSet_SL/AudioSet_SL{00..54}.tar
diff --git a/minigpt4/configs/datasets/bbc/defaults.yaml b/minigpt4/configs/datasets/bbc/defaults.yaml
new file mode 100644
index 0000000..a0aa380
--- /dev/null
+++ b/minigpt4/configs/datasets/bbc/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+  bbc:
+    data_type: audio
+    build_info:
+      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/BBC_Sound_Effects/BBC_Sound_Effects{000000..000062}.tar
diff --git a/minigpt4/configs/datasets/freesound/defaults.yaml b/minigpt4/configs/datasets/freesound/defaults.yaml
new file mode 100644
index 0000000..edc520f
--- /dev/null
+++ b/minigpt4/configs/datasets/freesound/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+  freesound:
+    data_type: audio
+    build_info:
+      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/FreeSound/FreeSound{000000..000524}.tar
diff --git a/minigpt4/configs/datasets/soundbible/defaults.yaml b/minigpt4/configs/datasets/soundbible/defaults.yaml
new file mode 100644
index 0000000..7b46a67
--- /dev/null
+++ b/minigpt4/configs/datasets/soundbible/defaults.yaml
@@ -0,0 +1,5 @@
+datasets:
+  soundbible:
+    data_type: audio
+    build_info:
+      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/SoundBible/SoundBible0.tar
diff --git a/minigpt4/datasets/builders/__init__.py b/minigpt4/datasets/builders/__init__.py
index 744fe2f..7fd3f7d 100644
--- a/minigpt4/datasets/builders/__init__.py
+++ b/minigpt4/datasets/builders/__init__.py
@@ -11,12 +11,23 @@ from minigpt4.datasets.builders.image_text_pair_builder import (
     LaionBuilderImage,
     CCSBUAlignBuilderImage
 )
+from minigpt4.datasets.builders.audio_text_pair_builder import (
+    BBCBuilder,
+    AudioSetBuilder,
+    SoundBibleBuilder,
+    FreeSoundBuilder
+)
 from minigpt4.common.registry import registry
 
 __all__ = [
     "CCSBUBuilderImage",
     "LaionBuilderImage",
-    "CCSBUAlignBuilderImage"
+    "CCSBUAlignBuilderImage",
+    # Audio builders
+    "BBCBuilder",
+    "AudioSetBuilder",
+    "SoundBibleBuilder",
+    "FreeSoundBuilder",
 ]
 
 
diff --git a/minigpt4/datasets/builders/audio_base_dataset_builder.py b/minigpt4/datasets/builders/audio_base_dataset_builder.py
index 4e1361e..711fb7f 100644
--- a/minigpt4/datasets/builders/audio_base_dataset_builder.py
+++ b/minigpt4/datasets/builders/audio_base_dataset_builder.py
@@ -31,6 +31,51 @@ class AudioBaseDatasetBuilder:
 
         self.data_type = self.config.data_type
 
+        self.audio_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+
+    def build_datasets(self):
+        # download, split, etc...
+        # only called on 1 GPU/TPU in distributed
+
+        if is_main_process():
+            self._download_data()
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+        logging.info("Building datasets...")
+        datasets = self.build()  # dataset['train'/'val'/'test']
+
+        return datasets
+
+    def build_processors(self):
+        aud_proc_cfg = self.config.get("audio_processor")
+        txt_proc_cfg = self.config.get("text_processor")
+
+        if aud_proc_cfg is not None:
+            aud_train_cfg = aud_proc_cfg.get("train")
+            aud_eval_cfg = aud_proc_cfg.get("eval")
+
+            self.audio_processors["train"] = self._build_proc_from_cfg(aud_train_cfg)
+            self.audio_processors["eval"] = self._build_proc_from_cfg(aud_eval_cfg)
+
+        if txt_proc_cfg is not None:
+            txt_train_cfg = txt_proc_cfg.get("train")
+            txt_eval_cfg = txt_proc_cfg.get("eval")
+
+            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
+            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
+
+    @staticmethod
+    def _build_proc_from_cfg(cfg):
+        return (
+            registry.get_processor_class(cfg.name).from_config(cfg)
+            if cfg is not None
+            else None
+        )
+
     @classmethod
     def default_config_path(cls, type="default"):
         return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
@@ -95,5 +140,3 @@ class AudioBaseDatasetBuilder:
                         filename = os.path.basename(storage_path)
 
                     download_url(url=url_or_filename, root=dirname, filename=filename)
-
-
diff --git a/minigpt4/datasets/builders/audio_text_pair_builder.py b/minigpt4/datasets/builders/audio_text_pair_builder.py
new file mode 100644
index 0000000..d63bdcb
--- /dev/null
+++ b/minigpt4/datasets/builders/audio_text_pair_builder.py
@@ -0,0 +1,56 @@
+import os
+import logging
+import warnings
+
+from minigpt4.common.registry import registry
+from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
+from minigpt4.datasets.datasets.audio_caption import GenericAudioDataset
+
+
+
+class GenericAudioBuilder(AudioBaseDatasetBuilder):
+    train_dataset_cls = GenericAudioDataset
+
+    def _download_ann(self):
+        pass
+
+    def _download_aud(self):
+        pass
+
+    def build(self):
+        self.build_processors()
+
+        build_info = self.config.build_info
+
+        datasets = dict()
+        split = "train"
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets[split] = dataset_cls(
+            audio_processor=self.audio_processors[split],
+            text_processor=self.text_processors[split],
+            location=build_info.storage,
+        ).inner_dataset
+
+        return datasets
+    
+
+@registry.register_builder("bbc")
+class BBCBuilder(GenericAudioBuilder):
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/bbc/defaults.yaml"}
+
+
+@registry.register_builder("audioset")
+class AudioSetBuilder(GenericAudioBuilder):
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults.yaml"}
+
+
+@registry.register_builder("soundbible")
+class SoundBibleBuilder(GenericAudioBuilder):
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/soundbible/defaults.yaml"}
+
+
+@registry.register_builder("freesound")
+class FreeSoundBuilder(GenericAudioBuilder):
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/freesound/defaults.yaml"}
diff --git a/minigpt4/datasets/datasets/audio_caption/__init__.py b/minigpt4/datasets/datasets/audio_caption/__init__.py
index e69de29..c511eb3 100644
--- a/minigpt4/datasets/datasets/audio_caption/__init__.py
+++ b/minigpt4/datasets/datasets/audio_caption/__init__.py
@@ -0,0 +1 @@
+from minigpt4.datasets.datasets.audio_caption.audio_caption_datasets import GenericAudioDataset
diff --git a/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py b/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py
index 52b6dcd..57aa027 100644
--- a/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py
+++ b/minigpt4/datasets/datasets/audio_caption/audio_caption_datasets.py
@@ -6,21 +6,22 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset
 
 
 class GenericAudioDataset(BaseDataset):
-    def __init__(self, vision_processor, text_processor, location):
-        super().__init__(x_processor=vision_processor, text_processor=text_processor)
+    def __init__(self, audio_processor, text_processor, location):
+        super().__init__(x_processor=audio_processor, text_processor=text_processor)
 
         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
             wds.tarfile_to_samples(handler=wds.warn_and_continue),
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
-            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+            wds.to_tuple("flac", "json", handler=wds.warn_and_continue),
             wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )
 
     def to_dict(self, sample):
         return {
-            "image": sample[0],
+            "audio": sample[0],
+            # [clips_per_video, channel, mel_bins, time_steps]
             "text_input": self.text_processor(sample[1]["caption"]),
         }
diff --git a/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py b/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py
index eaaba91..9a4ffde 100644
--- a/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py
+++ b/minigpt4/datasets/datasets/image_caption/cc_sbu_dataset.py
@@ -21,7 +21,7 @@ class CCSBUDataset(BaseDataset):
 
     def to_dict(self, sample):
         return {
-            "image": sample[0],
+            "vision": sample[0],
             "text_input": self.text_processor(sample[1]["caption"]),
         }
 
@@ -41,7 +41,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
         caption = ann["caption"]
 
         return {
-            "image": image,
+            "vision": image,
             "text_input": caption,
             "image_id": self.img_ids[ann["image_id"]],
         }
@@ -49,7 +49,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
 
 class CCDataset(BaseDataset):
     def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+        super().__init__(x_processor=vis_processor, text_processor=text_processor)
 
         self.inner_dataset = wds.DataPipeline(
             wds.ResampledShards(location),
@@ -57,12 +57,12 @@ class CCDataset(BaseDataset):
             wds.shuffle(1000, handler=wds.warn_and_continue),
             wds.decode("pilrgb", handler=wds.warn_and_continue),
             wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
             wds.map(self.to_dict, handler=wds.warn_and_continue),
         )
 
     def to_dict(self, sample):
         return {
-            "image": sample[0],
+            "vision": sample[0],
             "text_input": sample[1],
         }
diff --git a/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py b/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py
index 120fcf3..607d888 100644
--- a/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py
+++ b/minigpt4/datasets/datasets/image_caption/image_caption_datasets.py
@@ -42,7 +42,7 @@ class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
         caption = self.text_processor(ann["caption"])
 
         return {
-            "image": image,
+            "vision": image,
             "text_input": caption,
             "image_id": self.img_ids[ann["image_id"]],
         }
@@ -67,7 +67,7 @@ class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
         image = self.x_processor(image)
 
         return {
-            "image": image,
+            "vision": image,
             "image_id": ann["image_id"],
             "instance_id": ann["instance_id"],
         }
diff --git a/minigpt4/datasets/datasets/image_caption/laion_dataset.py b/minigpt4/datasets/datasets/image_caption/laion_dataset.py
index 8cb34a8..e56160c 100644
--- a/minigpt4/datasets/datasets/image_caption/laion_dataset.py
+++ b/minigpt4/datasets/datasets/image_caption/laion_dataset.py
@@ -25,7 +25,7 @@ class LaionDataset(BaseDataset):
 
     def to_dict(self, sample):
         return {
-            "image": sample[0],
+            "vision": sample[0],
             "text_input": self.text_processor(sample[1]["caption"]),
         }
 
diff --git a/minigpt4/datasets/datasets/mixins/mixins.py b/minigpt4/datasets/datasets/mixins/mixins.py
index e143147..1c677fd 100644
--- a/minigpt4/datasets/datasets/mixins/mixins.py
+++ b/minigpt4/datasets/datasets/mixins/mixins.py
@@ -9,7 +9,7 @@ class __ImageDisplMixin:
             {
                 "file": ann["image"],
                 "caption": ann["caption"],
-                "image": sample["image"],
+                "vision": sample["vision"],
             }
         )
 
diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py
index 1320ced..a5973ba 100644
--- a/minigpt4/models/bind_gpt4.py
+++ b/minigpt4/models/bind_gpt4.py
@@ -1,8 +1,9 @@
 import random
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List
 
 import torch
 import torch.nn as nn
+import re
 from torch import Tensor
 from transformers import LlamaTokenizer
 
@@ -12,6 +13,36 @@ from minigpt4.models.blip2 import BaseModel
 from minigpt4.models.modeling_llama import LlamaForCausalLM
 
 
+def filter_prompt(input_embeds: Dict[str, Tensor], prompt_list: List[str]) -> List[str]:
+    if not prompt_list:
+        return prompt_list
+    input_modal_set = set([k.title() for k in input_embeds if input_embeds[k] is not None])
+    prompt_modal_sets = [set(re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)) for prompt in prompt_list]
+    results = [prompt_list[i] for i, prompt_modal_set in enumerate(prompt_modal_sets) if
+               prompt_modal_set == input_modal_set]
+    return results
+
+
+def arrange_modalities(input_embeds: Dict[str, Tensor], prompt: str) -> List[Tensor]:
+    prompt_modalities = re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)
+    return [input_embeds[modality.lower()] for modality in prompt_modalities]
+
+
+def concat_all_embeddings(input_embeds: Dict[str, Tensor], dim: int) -> Tensor:
+    embeds = [input_embeds[key] for key in input_embeds if input_embeds[key] is not None]
+    return torch.cat(embeds, dim=dim)
+
+
+def filter_modalities(inputs):
+    filtered_inputs = {}
+
+    for k in ModalityType.__dict__.values():
+        if k in inputs:
+            filtered_inputs[k] = inputs[k]
+
+    return filtered_inputs
+
+
 @registry.register_model("bind_gpt4")
 class BindGPT4(BaseModel):
     """
@@ -61,7 +92,9 @@ class BindGPT4(BaseModel):
         print('Loading Q-Former and Adapter/Projector')
         self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
                                                  vision_qformer_frozen=freeze_qformer,
-                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
+                                                 vision_post_dims=[768, self.llama_model.config.hidden_size],
+                                                 audio_query_token_num=num_query_token,
+                                                 audio_post_dims=[768, self.llama_model.config.hidden_size]
                                                  # vision_qformer_model=q_former_model,
                                                  # vision_pre_dims=(1280, 1408)
                                                  )
@@ -84,42 +117,37 @@ class BindGPT4(BaseModel):
     def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
         imagebind_outputs = self.multimodal_encoder(inputs)
         llama_inputs = self.multimodal_joiner(imagebind_outputs)
-        # NOTE: only accept image here
         return llama_inputs
 
-    def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
-        # TODO: Accept More Modalities.
-        input_embeds = inputs[modality_name]
-        attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
-        if prompt:
-            batch_size = input_embeds.shape[0]
-            p_before, p_after = prompt.split('<ModalityHere>')
-            p_before_tokens = self.llama_tokenizer(
-                p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
-            p_after_tokens = self.llama_tokenizer(
-                p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
-            p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
-            p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
-            wrapped_input_embeds = torch.cat([p_before_embeds, inputs, p_after_embeds], dim=1)
-            wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1])
-            return wrapped_input_embeds, wrapped_atts_input
-        else:
+    def prompt_wrap(self, inputs: Dict[str, Tensor], prompt: str) -> Tuple[Tensor, Tensor]:
+        if not prompt:
+            input_embeds = concat_all_embeddings(inputs, dim=1)
+            attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
             return input_embeds, attns_input
+        input_embeds_list = arrange_modalities(inputs, prompt)
+        batch_size = input_embeds_list[0].shape[0]
+        prompt_slices = prompt.split('<ModalityHere>')
+        prompt_tokens = [self.llama_tokenizer(prompt_slice, return_tensors="pt", add_special_tokens=False)
+                             .to(input_embeds_list[0].device) for prompt_slice in prompt_slices]
+        prompt_embeds = [self.llama_model.model.embed_tokens(prompt_token.input_ids).expand(batch_size, -1, -1)
+                         for prompt_token in prompt_tokens]
+        result_embeds = [emb for pair in zip(prompt_embeds[:-1], input_embeds_list)
+                         for emb in pair] + [prompt_embeds[-1]]
+        wrapped_input_embeds = torch.cat(result_embeds, dim=1)
+        wrapped_atts_input = torch.ones(wrapped_input_embeds.size()[:-1],
+                                        dtype=torch.long).to(wrapped_input_embeds.device)
+        return wrapped_input_embeds, wrapped_atts_input
 
     def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        """
-            TODO: More Modalities.
-            Only accept image inputs here.
-            Other modalities will conflict with the pre-defined prompt and wrapping strategy.
-        """
-        bind_inputs = {ModalityType.VISION: inputs['image']}
-        embeds = self.encode_inputs(bind_inputs)
-        # assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
-        if self.prompt_list:
-            prompt = random.choice(self.prompt_list)
+        # filter `inputs` as it may contain informatioins other than modalities
+        modality_inputs = filter_modalities(inputs)
+        embeds = self.encode_inputs(modality_inputs)
+        filtered_prompts = filter_prompt(embeds, self.prompt_list)
+        if filtered_prompts:
+            prompt = random.choice(filtered_prompts)
         else:
             prompt = None
-        img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
+        input_embs, input_atts = self.prompt_wrap(embeds, prompt)
 
         # NOTE: No modifications from the next line to the end. Except for the autocast part.
 
@@ -134,28 +162,28 @@ class BindGPT4(BaseModel):
             truncation=True,
             max_length=self.max_txt_len,
             add_special_tokens=False
-        ).to(img_embeds.device)
+        ).to(input_embs.device)
 
         targets = to_regress_tokens.input_ids.masked_fill(
             to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
         )
 
         empty_targets = (
-            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
-                       dtype=torch.long).to(img_embeds.device).fill_(-100)  # plus one for bos
+            torch.ones([input_atts.shape[0], input_atts.shape[1] + 1],
+                       dtype=torch.long).to(input_embs.device).fill_(-100)  # plus one for bos
         )
         targets = torch.cat([empty_targets, targets], dim=1)
 
-        batch_size = img_embeds.shape[0]
+        batch_size = input_embs.shape[0]
         bos = torch.ones([batch_size, 1],
                          dtype=to_regress_tokens.input_ids.dtype,
                          device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
         bos_embeds = self.llama_model.model.embed_tokens(bos)
-        atts_bos = atts_img[:, :1]
+        atts_bos = input_atts[:, :1]
 
         to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
-        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
-        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+        inputs_embeds = torch.cat([bos_embeds, input_embs, to_regress_embeds], dim=1)
+        attention_mask = torch.cat([atts_bos, input_atts, to_regress_tokens.attention_mask], dim=1)
 
         outputs = self.llama_model(
             inputs_embeds=inputs_embeds,
diff --git a/minigpt4/models/blip2.py b/minigpt4/models/blip2.py
index e33bf82..46c1eea 100644
--- a/minigpt4/models/blip2.py
+++ b/minigpt4/models/blip2.py
@@ -144,7 +144,7 @@ def compute_sim_matrix(model, data_loader, **kwargs):
     vit_feats = []
     image_embeds = []
     for samples in data_loader:
-        image = samples["image"]
+        image = samples["vision"]
 
         image = image.to(model.device)
         image_feat, vit_feat = model.forward_image(image)
diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index 19e04f5..90ec011 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -164,7 +164,7 @@ class MiniGPT4(Blip2Base):
             return img_embeds, atts_img
 
     def forward(self, samples):
-        image = samples["image"]
+        image = samples["vision"]
         img_embeds, atts_img = self.encode_img(image)
         if hasattr(samples, 'question_split'):  # VQA dataset
             print('VQA Batch')
diff --git a/minigpt4/processors/__init__.py b/minigpt4/processors/__init__.py
index 0ce174e..57b6d9a 100644
--- a/minigpt4/processors/__init__.py
+++ b/minigpt4/processors/__init__.py
@@ -11,11 +11,15 @@ from minigpt4.processors.blip_processors import (
     Blip2ImageEvalProcessor,
     BlipCaptionProcessor,
 )
-from minigpt4.processors.imagebind_processor import (
+from minigpt4.processors.imagebind_vision_processor import (
     ImageBindCaptionProcessor,
     ImageBindVisionTrainProcessor,
     ImageBindVisionEvalProcessor
 )
+from minigpt4.processors.imagebind_audio_processor import (
+    ImageBindAudioTrainProcessor,
+    ImageBindAudioEvalProcessor,
+)
 
 from minigpt4.common.registry import registry
 
@@ -26,7 +30,9 @@ __all__ = [
     "BlipCaptionProcessor",
     "ImageBindCaptionProcessor",
     "ImageBindVisionTrainProcessor",
-    "ImageBindVisionEvalProcessor"
+    "ImageBindVisionEvalProcessor",
+    "ImageBindAudioTrainProcessor",
+    "ImageBindAudioEvalProcessor",
 ]
 
 
diff --git a/minigpt4/processors/audio_augment.py b/minigpt4/processors/audio_augment.py
new file mode 100644
index 0000000..869db2e
--- /dev/null
+++ b/minigpt4/processors/audio_augment.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# @Author  : Xinhao Mei @CVSSP, University of Surrey
+# @E-mail  : x.mei@surrey.ac.uk
+
+"""
+    Implemenation of SpecAugment++,
+    Adapated from Qiuqiang Kong's trochlibrosa:
+    https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
+"""
+
+import torch
+import torch.nn as nn
+
+
+class DropStripes:
+
+    def __init__(self, dim, drop_width, stripes_num):
+        """ Drop stripes.
+        args:
+            dim: int, dimension along which to drop
+            drop_width: int, maximum width of stripes to drop
+            stripes_num: int, how many stripes to drop
+        """
+        super(DropStripes, self).__init__()
+
+        assert dim in [2, 3]  # dim 2: time; dim 3: frequency
+
+        self.dim = dim
+        self.drop_width = drop_width
+        self.stripes_num = stripes_num
+
+    def __call__(self, input):
+        """input: (batch_size, channels, time_steps, freq_bins)"""
+
+        assert input.ndimension() == 4
+        batch_size = input.shape[0]
+        total_width = input.shape[self.dim]
+
+        for n in range(batch_size):
+            self.transform_slice(input[n], total_width)
+
+        return input
+
+    def transform_slice(self, e, total_width):
+        """ e: (channels, time_steps, freq_bins)"""
+
+        for _ in range(self.stripes_num):
+            distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
+            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
+
+            if self.dim == 2:
+                e[:, bgn: bgn + distance, :] = 0
+            elif self.dim == 3:
+                e[:, :, bgn: bgn + distance] = 0
+
+
+class MixStripes:
+
+    def __init__(self, dim, mix_width, stripes_num):
+        """ Mix stripes
+        args:
+            dim: int, dimension along which to mix
+            mix_width: int, maximum width of stripes to mix
+            stripes_num: int, how many stripes to mix
+        """
+
+        super(MixStripes, self).__init__()
+
+        assert dim in [2, 3]
+
+        self.dim = dim
+        self.mix_width = mix_width
+        self.stripes_num = stripes_num
+
+    def __call__(self, input):
+        """input: (batch_size, channel, time_steps, freq_bins)"""
+
+        assert input.ndimension() == 4
+
+        batch_size = input.shape[0]
+        total_width = input.shape[self.dim]
+
+        rand_sample = input[torch.randperm(batch_size)]
+        for i in range(batch_size):
+            self.transform_slice(input[i], rand_sample[i], total_width)
+        return input
+
+    def transform_slice(self, input, random_sample, total_width):
+
+        for _ in range(self.stripes_num):
+            distance = torch.randint(low=0, high=self.mix_width, size=(1,))[0]
+            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
+
+            if self.dim == 2:
+                input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \
+                                                   0.5 * random_sample[:, bgn: bgn + distance, :]
+            elif self.dim == 3:
+                input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \
+                                                   0.5 * random_sample[:, :, bgn: bgn + distance]
+
+
+class CutStripes:
+
+    def __init__(self, dim, cut_width, stripes_num):
+        """ Cutting stripes with another randomly selected sample in mini-batch.
+        args:
+            dim: int, dimension along which to cut
+            cut_width: int, maximum width of stripes to cut
+            stripes_num: int, how many stripes to cut
+        """
+
+        super(CutStripes, self).__init__()
+
+        assert dim in [2, 3]
+
+        self.dim = dim
+        self.cut_width = cut_width
+        self.stripes_num = stripes_num
+
+    def __call__(self, input):
+        """input: (batch_size, channel, time_steps, freq_bins)"""
+
+        assert input.ndimension() == 4
+
+        batch_size = input.shape[0]
+        total_width = input.shape[self.dim]
+
+        rand_sample = input[torch.randperm(batch_size)]
+        for i in range(batch_size):
+            self.transform_slice(input[i], rand_sample[i], total_width)
+        return input
+
+    def transform_slice(self, input, random_sample, total_width):
+
+        for _ in range(self.stripes_num):
+            distance = torch.randint(low=0, high=self.cut_width, size=(1,))[0]
+            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
+
+            if self.dim == 2:
+                input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :]
+            elif self.dim == 3:
+                input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance]
+
+
+class SpecAugmentation:
+
+    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
+                 mask_type='mixture'):
+        """Spec augmetation and SpecAugment++.
+        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
+        and Le, Q.V., 2019. Specaugment: A simple data augmentation method
+        for automatic speech recognition. arXiv preprint arXiv:1904.08779.
+        [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
+        Data Augmentation Method for Acoustic Scene Classification. arXiv
+        preprint arXiv:2103.16858.
+
+        Args:
+            time_drop_width: int
+            time_stripes_num: int
+            freq_drop_width: int
+            freq_stripes_num: int
+            mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)
+        """
+
+        super(SpecAugmentation, self).__init__()
+
+        if mask_type == 'zero_value':
+            self.time_augmentator = DropStripes(dim=2, drop_width=time_drop_width,
+                                                stripes_num=time_stripes_num)
+            self.freq_augmentator = DropStripes(dim=3, drop_width=freq_drop_width,
+                                                stripes_num=freq_stripes_num)
+        elif mask_type == 'mixture':
+            self.time_augmentator = MixStripes(dim=2, mix_width=time_drop_width,
+                                               stripes_num=time_stripes_num)
+            self.freq_augmentator = MixStripes(dim=3, mix_width=freq_drop_width,
+                                               stripes_num=freq_stripes_num)
+        elif mask_type == 'cutting':
+            self.time_augmentator = CutStripes(dim=2, cut_width=time_drop_width,
+                                               stripes_num=time_stripes_num)
+            self.freq_augmentator = CutStripes(dim=3, cut_width=freq_drop_width,
+                                               stripes_num=freq_stripes_num)
+        else:
+            raise NameError('No such mask type in SpecAugment++')
+
+    def __call__(self, inputs):
+        # x should be in size [batch_size, channel, time_steps, freq_bins]
+        x = self.time_augmentator(inputs)
+        x = self.freq_augmentator(x)
+        return x
diff --git a/minigpt4/processors/blip_processors.py b/minigpt4/processors/blip_processors.py
index fd26160..d2bc43f 100644
--- a/minigpt4/processors/blip_processors.py
+++ b/minigpt4/processors/blip_processors.py
@@ -9,7 +9,7 @@ import re
 
 from minigpt4.common.registry import registry
 from minigpt4.processors.base_processor import BaseProcessor
-from minigpt4.processors.randaugment import RandomAugment
+from minigpt4.processors.vision_augment import RandomAugment
 from omegaconf import OmegaConf
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
diff --git a/minigpt4/processors/imagebind_audio_processor.py b/minigpt4/processors/imagebind_audio_processor.py
new file mode 100644
index 0000000..fa64cb3
--- /dev/null
+++ b/minigpt4/processors/imagebind_audio_processor.py
@@ -0,0 +1,166 @@
+from typing import Union, List
+
+import torch
+import torchaudio
+from omegaconf import OmegaConf
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
+from torch import Tensor
+
+from imagebind.data.data_utils import waveform2melspec, get_constant_clip_timepoints, \
+    get_random_clip_timepoints
+from minigpt4.processors.base_processor import BaseProcessor
+from torchvision import transforms
+from minigpt4.common.registry import registry
+from minigpt4.processors.audio_augment import SpecAugmentation
+
+
+class ImageBindAudioBaseProcessor(BaseProcessor):
+    def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
+                 num_mel_bins=None, target_length=None, clip_sample_method="Random"):
+        super().__init__()
+        self.mean = -4.268 if mean is None else mean
+        self.std = 9.138 if std is None else std
+        self.target_sr = 16000 if target_sr is None else target_sr
+        self.num_mel_bins = num_mel_bins
+        self.target_length = target_length
+        self.clip_sampler = self._construct_clip_sampler(clip_duration, clips_per_video, clip_sample_method)
+        self.normalize = transforms.Normalize(self.mean, self.std)
+
+    def _construct_clip_sampler(self, clip_duration, clips_per_video, clip_sample_method):
+        if clip_duration is None or clips_per_video is None:
+            return None
+        if clip_sample_method == "Constant":
+            return ConstantClipsPerVideoSampler(
+                clip_duration=clip_duration, clips_per_video=clips_per_video
+            )
+        elif clip_sample_method == "Random":
+            return RandomMultiClipSampler(clip_duration=clip_duration, num_clips=clips_per_video)
+        else:
+            raise NotImplementedError
+
+    def waveform_resample(self, waveform: Tensor, origin_sr: int) -> Tensor:
+        return torchaudio.functional.resample(
+            waveform, orig_freq=origin_sr, new_freq=self.target_sr
+        )
+
+    def clip_sample(self, waveform: Tensor) -> List[Tensor]:
+        if self.clip_sampler is None:
+            return [waveform]
+        elif isinstance(self.clip_sampler, ConstantClipsPerVideoSampler):
+            all_clips_timepoints = get_constant_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
+        elif isinstance(self.clip_sampler, RandomMultiClipSampler):
+            all_clips_timepoints = get_random_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
+        else:
+            raise NotImplementedError
+        all_clips = []
+        for clip_timepoints in all_clips_timepoints:
+            start_pos = int(clip_timepoints[0] * self.target_sr)
+            end_pos = int(clip_timepoints[1] * self.target_sr)
+            waveform_clip = waveform[:, start_pos: end_pos]
+            all_clips.append(waveform_clip)
+        return all_clips
+
+    def waveform_melspec(self, waveforms: Union[List[Tensor], Tensor]) -> List[Tensor]:
+        if isinstance(waveforms, Tensor):
+            return waveform2melspec(waveforms, self.target_sr, self.num_mel_bins, self.target_length)
+        else:
+            return [waveform2melspec(waveform, self.target_sr, self.num_mel_bins, self.target_length)
+                    for waveform in waveforms]
+
+
+@registry.register_processor("imagebind_audio_train")
+class ImageBindAudioTrainProcessor(ImageBindAudioBaseProcessor):
+    def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
+                 clip_sample_method="Random", num_mel_bins=None, target_length=None, time_drop_width=13,
+                 time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2, mask_type='mixture'):
+        super().__init__(mean=mean, std=std, target_sr=target_sr,
+                         clip_duration=clip_duration, clips_per_video=clips_per_video,
+                         num_mel_bins=num_mel_bins, target_length=target_length,
+                         clip_sample_method=clip_sample_method)
+        self.spec_augment = SpecAugmentation(time_drop_width, time_stripes_num,
+                                             freq_drop_width, freq_stripes_num, mask_type)
+
+    def __call__(self, item):
+        # item: Tuple[Tensor, int]
+        waveform, origin_sr = item[0], item[1]
+        waveform = self.waveform_resample(waveform, origin_sr)
+        waveform_clips = self.clip_sample(waveform)
+        melspec_clips = self.waveform_melspec(waveform_clips)
+        normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
+        all_clips = torch.stack(normed_melspecs, dim=0)
+        # all_clips: [clips_per_video, channel, mel_bins, time_steps]
+        # augment: [batch_size, channel, time_steps, freq_bins]
+        augmented_clips = self.spec_augment(all_clips.transpose(-2, -1)).transpose(-2, -1)
+        return augmented_clips
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        target_sr = cfg.get("target_sr", 16000)
+        clip_duration = cfg.get("clip_duration", 2)
+        clips_per_video = cfg.get("clips_per_video", 3)
+        num_mel_bins = cfg.get("num_mel_bins", 128)
+        target_length = cfg.get("target_length", 204)
+        time_drop_width = cfg.get("time_drop_width", 13)
+        time_stripes_num = cfg.get("time_stripes_num", 2)
+        # 13 * 2 / 204 = 12.75% Time Mask
+        freq_drop_width = cfg.get("freq_drop_width", 8)
+        freq_stripes_num = cfg.get("freq_stripes_num", 2)
+        # 8 * 2 / 128 = 12.5% Freq Mask
+        mask_type = cfg.get("mask_type", 'mixture')
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        return cls(
+            mean=mean, std=std, target_sr=target_sr,
+            clip_duration=clip_duration, clips_per_video=clips_per_video,
+            num_mel_bins=num_mel_bins, target_length=target_length,
+            time_drop_width=time_drop_width, time_stripes_num=time_stripes_num,
+            freq_drop_width=freq_drop_width, freq_stripes_num=freq_stripes_num,
+            mask_type=mask_type
+        )
+
+
+@registry.register_processor("imagebind_audio_eval")
+class ImageBindAudioEvalProcessor(ImageBindAudioBaseProcessor):
+    def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
+                 clip_sample_method="Constant", num_mel_bins=None, target_length=None):
+        super().__init__(mean=mean, std=std, target_sr=target_sr,
+                         clip_duration=clip_duration, clips_per_video=clips_per_video,
+                         num_mel_bins=num_mel_bins, target_length=target_length,
+                         clip_sample_method=clip_sample_method)
+
+    def __call__(self, item):
+        # item: Tuple[Tensor, int]
+        waveform, origin_sr = item[0], item[1]
+        waveform = self.waveform_resample(waveform, origin_sr)
+        waveform_clips = self.clip_sample(waveform)
+        melspec_clips = self.waveform_melspec(waveform_clips)
+        normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
+        all_clips = torch.stack(normed_melspecs, dim=0)
+        # all_clips: [clips_per_video, channel, mel_bins, time_steps]
+        return all_clips
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        target_sr = cfg.get("target_sr", 16000)
+        clip_duration = cfg.get("clip_duration", 2)
+        clips_per_video = cfg.get("clips_per_video", 3)
+        num_mel_bins = cfg.get("num_mel_bins", 128)
+        target_length = cfg.get("target_length", 204)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        return cls(
+            mean=mean, std=std, target_sr=target_sr,
+            clip_duration=clip_duration, clips_per_video=clips_per_video,
+            num_mel_bins=num_mel_bins, target_length=target_length
+        )
+
diff --git a/minigpt4/processors/imagebind_processor.py b/minigpt4/processors/imagebind_vision_processor.py
similarity index 98%
rename from minigpt4/processors/imagebind_processor.py
rename to minigpt4/processors/imagebind_vision_processor.py
index 837f499..67b92f1 100644
--- a/minigpt4/processors/imagebind_processor.py
+++ b/minigpt4/processors/imagebind_vision_processor.py
@@ -9,7 +9,7 @@ import re
 
 from minigpt4.common.registry import registry
 from minigpt4.processors.base_processor import BaseProcessor
-from minigpt4.processors.randaugment import RandomAugment
+from minigpt4.processors.vision_augment import RandomAugment
 from omegaconf import OmegaConf
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -147,3 +147,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
 
         return cls(image_size=image_size, mean=mean, std=std)
 
+
+
diff --git a/minigpt4/processors/randaugment.py b/minigpt4/processors/vision_augment.py
similarity index 100%
rename from minigpt4/processors/randaugment.py
rename to minigpt4/processors/vision_augment.py
diff --git a/requirements.txt b/requirements.txt
index eb572d3..00fb09d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -53,7 +53,6 @@ ipdb
 tensorflow-cpu
 tensorboardX
 # mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
-
 bs4==0.0.1              # Needed for text cleaning
 bson==0.5.10
 byted-dataloader==0.3.6
@@ -68,4 +67,5 @@ tensorflow-io==0.30.0   # Needed for tensorboard hdfs support
 tqdm==4.64.1
 
 git+https://github.com/facebookresearch/segment-anything.git
-git+https://github.com/IDEA-Research/GroundingDINO.git
\ No newline at end of file
+git+https://github.com/IDEA-Research/GroundingDINO.git
+pytorchvideo==0.1.5
diff --git a/train_configs/bindgpt4_stage1_audio_pretrain.yaml b/train_configs/bindgpt4_stage1_audio_pretrain.yaml
new file mode 100644
index 0000000..9a2c023
--- /dev/null
+++ b/train_configs/bindgpt4_stage1_audio_pretrain.yaml
@@ -0,0 +1,68 @@
+model:
+  arch: bind_gpt4
+  model_type: pretrain_vicuna
+  freeze_imagebind: True
+  freeze_qformer: False
+
+datasets:
+  bbc:
+    audio_processor:
+      train:
+        name: "imagebind_audio_train"
+    text_processor:
+      train:
+        name: "imagebind_caption"
+    sample_ratio: 31
+  audioset:
+    audio_processor:
+      train:
+        name: "imagebind_audio_train"
+    text_processor:
+      train:
+        name: "imagebind_caption"
+    sample_ratio: 108
+  soundbible:
+    audio_processor:
+      train:
+        name: "imagebind_audio_train"
+    text_processor:
+      train:
+        name: "imagebind_caption"
+    sample_ratio: 2
+  freesound:
+    audio_processor:
+      train:
+        name: "imagebind_audio_train"
+    text_processor:
+      train:
+        name: "imagebind_caption"
+    sample_ratio: 262
+
+run:
+  task: image_text_pretrain
+  # optimizer
+  lr_sched: "linear_warmup_cosine_lr"
+  init_lr: 1e-4
+  min_lr: 8e-5
+  warmup_lr: 1e-6
+
+  weight_decay: 0.05
+  max_epoch: 4
+  batch_size_train: 64
+  batch_size_eval: 64
+  num_workers: 4
+  warmup_steps: 5000
+  iters_per_epoch: 5000
+
+  seed: 42
+  output_dir: "output/bindgpt4_stage1_audio_pretrain"
+
+  amp: True
+  resume_ckpt_path: null
+  evaluate: False
+  train_splits: ["train"]
+
+  device: "cuda"
+  world_size: 1
+  dist_url: "env://"
+  distributed: True
\ No newline at end of file