mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
init audio data config (#2)
- Add audio datasets
- Add audio processors
- Add audio support in bindgpt
- Add audio training config

Co-authored-by: bingyikang <bingyikang@bytedance.com>
Co-authored-by: zhaoyang <913556700@qq.com>
parent
3efda2ac76
commit
05220fe3c1
175
README.md
@ -1,170 +1,5 @@
|
||||
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
|
||||
[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
|
||||
|
||||
**King Abdullah University of Science and Technology**
|
||||
|
||||
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
|
||||
|
||||
|
||||
## News
|
||||
We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB.
|
||||
|
||||
|
||||
## Online Demo
|
||||
|
||||
Click the image to chat with MiniGPT-4 about your images
|
||||
[](https://minigpt-4.github.io)
|
||||
|
||||
|
||||
## Examples
|
||||
| | |
|
||||
:-------------------------:|:-------------------------:
|
||||
 | 
|
||||
 | 
|
||||
|
||||
More examples can be found in the [project page](https://minigpt-4.github.io).
|
||||
|
||||
|
||||
|
||||
## Introduction
|
||||
- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer.
|
||||
- We train MiniGPT-4 in two stages. The first, a traditional pretraining stage, uses roughly 5 million aligned image-text pairs and takes about 10 hours on 4 A100s. After the first stage, Vicuna is able to understand the image, but its generation ability is heavily impacted.
|
||||
- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset.
|
||||
- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100.
|
||||
- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4.
|
||||
|
||||
|
||||

|
||||
|
||||
|
||||
## Getting Started
|
||||
### Installation
|
||||
|
||||
**1. Prepare the code and the environment**
|
||||
|
||||
Clone our repository, then create and activate a Python environment with the following commands:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
|
||||
cd MiniGPT-4
|
||||
conda env create -f environment.yml
|
||||
conda activate minigpt4
|
||||
```
|
||||
|
||||
|
||||
**2. Prepare the pretrained Vicuna weights**
|
||||
|
||||
The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B.
|
||||
Please refer to our instruction [here](PrepareVicuna.md)
|
||||
to prepare the Vicuna weights.
|
||||
The final weights would be in a single folder in a structure similar to the following:
|
||||
|
||||
```
|
||||
vicuna_weights
|
||||
├── config.json
|
||||
├── generation_config.json
|
||||
├── pytorch_model.bin.index.json
|
||||
├── pytorch_model-00001-of-00003.bin
|
||||
...
|
||||
```
|
||||
|
||||
Then, set the path to the vicuna weight in the model config file
|
||||
[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
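
Before editing the config, it can help to confirm that the weights folder matches the layout listed above. The following is a minimal sketch and not part of the repository: `check_vicuna_weights` is a hypothetical helper, and the required file names are taken only from the folder listing above.

```python
from pathlib import Path

# File names taken from the folder listing above; other files may also be present.
REQUIRED_FILES = ("config.json", "generation_config.json", "pytorch_model.bin.index.json")


def check_vicuna_weights(folder):
    """Return a list of problems found in a Vicuna weights folder (empty if none)."""
    root = Path(folder)
    problems = [f"missing {name}" for name in REQUIRED_FILES if not (root / name).is_file()]
    # The listing shows sharded weights such as pytorch_model-00001-of-00003.bin.
    if not list(root.glob("pytorch_model-*.bin")):
        problems.append("no pytorch_model-*.bin shards found")
    return problems
```

Calling `check_vicuna_weights("vicuna_weights")` returns an empty list when every expected file is present, so it can be used as a quick pre-flight check before launching the demo.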
|
||||
|
||||
**3. Prepare the pretrained MiniGPT-4 checkpoint**
|
||||
|
||||
Download the pretrained checkpoints according to the Vicuna model you prepare.
|
||||
|
||||
| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
|
||||
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
|
||||
[Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing)
|
||||
|
||||
|
||||
Then, set the path to the pretrained checkpoint in the evaluation config file
|
||||
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
|
||||
|
||||
|
||||
|
||||
### Launching Demo Locally
|
||||
|
||||
Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running
|
||||
|
||||
```
|
||||
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
|
||||
```
|
||||
|
||||
To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1.
|
||||
This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B.
|
||||
For more powerful GPUs, you can run the model
|
||||
in 16 bit by setting low_resource to False in the config file
|
||||
[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
|
||||
|
||||
Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
|
||||
|
||||
|
||||
### Training
|
||||
The training of MiniGPT-4 contains two alignment stages.
|
||||
|
||||
**1. First pretraining stage**
|
||||
|
||||
In the first pretraining stage, the model is trained using image-text pairs from the Laion and CC datasets
|
||||
to align the vision and language model. To download and prepare the datasets, please check
|
||||
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
|
||||
After the first stage, the visual features are mapped and can be understood by the language
|
||||
model.
|
||||
To launch the first stage training, run the following command. In our experiments, we use 4 A100.
|
||||
You can change the save path in the config file
|
||||
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
|
||||
```
|
||||
|
||||
A MiniGPT-4 checkpoint with only stage one training can be downloaded
|
||||
[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
|
||||
Compared to the model after stage two, this checkpoint frequently generates incomplete and repeated sentences.
|
||||
|
||||
|
||||
**2. Second finetuning stage**
|
||||
|
||||
In the second stage, we use a small high quality image-text pair dataset created by ourselves
|
||||
and convert it to a conversation format to further align MiniGPT-4.
|
||||
To download and prepare our second stage dataset, please check our
|
||||
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
|
||||
To launch the second stage alignment,
|
||||
first specify the path to the checkpoint file trained in stage 1 in
|
||||
[train_configs/minigpt4_stage2_finetune.yaml](train_configs/minigpt4_stage2_finetune.yaml).
|
||||
You can also specify the output path there.
|
||||
Then, run the following command. In our experiments, we use 1 A100.
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
|
||||
```
|
||||
|
||||
After the second stage alignment, MiniGPT-4 is able to talk about the image in a coherent and user-friendly way.
|
||||
|
||||
|
||||
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check out this great open-source work if you haven't seen it before!
|
||||
+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
|
||||
+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
|
||||
|
||||
|
||||
If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX:
|
||||
```bibtex
|
||||
@article{zhu2023minigpt,
|
||||
title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
|
||||
author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
|
||||
journal={arXiv preprint arXiv:2304.10592},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## License
|
||||
This repository is under [BSD 3-Clause License](LICENSE.md).
|
||||
Much of the code is based on [Lavis](https://github.com/salesforce/LAVIS) with
|
||||
BSD 3-Clause License [here](LICENSE_Lavis.md).
|
||||
# TODO List
|
||||
- [ ] reconstruct the eval scripts
|
||||
- [ ] support audio evaluation
|
||||
- [ ] clean code repo & solve recursive import
|
||||
- [x] audio processor: random sampling
|
@ -6,7 +6,9 @@
|
||||
pip3 install --upgrade pip
|
||||
pip3 install -r requirements.txt
|
||||
pip3 install byted-dataloader -i "https://bytedpypi.byted.org/simple"
|
||||
mmengine-0.7.3
|
||||
pip3 install mmengine==0.7.3
|
||||
pip3 install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
||||
|
||||
# unset http_proxy && unset https_proxy && unset no_proxy
|
||||
|
||||
# # ----------------------------------------------------------------------------------------
|
||||
|
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@ -13,3 +14,12 @@ def load_image(image, image_processor):
|
||||
if len(image.shape) == 3:
|
||||
image = image.unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def load_audio(audio, audio_processor):
|
||||
if isinstance(audio, str): # is an audio path
|
||||
raw_audio = torchaudio.load(audio)
|
||||
audio = audio_processor(audio)
|
||||
# elif isinstance(audio, )
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -125,6 +125,7 @@ with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
image = gr.Image(type="pil")
|
||||
# audio = gr.Audio()
|
||||
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
|
||||
clear = gr.Button("Restart")
|
||||
|
||||
@ -150,7 +151,7 @@ with gr.Blocks() as demo:
|
||||
chat_state = gr.State()
|
||||
emb_list = gr.State()
|
||||
chatbot = gr.Chatbot(label='BindGPT-4')
|
||||
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
|
||||
text_input = gr.Textbox(label='User', placeholder='Please upload your image/audio first', interactive=False)
|
||||
|
||||
upload_button.click(upload_img, [image, text_input, chat_state],
|
||||
[image, text_input, upload_button, chat_state, emb_list])
|
||||
|
@ -15,7 +15,7 @@ import logging
|
||||
from imagebind.models.multimodal_preprocessors import SimpleTokenizer
|
||||
from PIL import Image
|
||||
from pytorchvideo import transforms as pv_transforms
|
||||
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
|
||||
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
|
||||
from pytorchvideo.data.encoded_video import EncodedVideo
|
||||
|
||||
from torchvision import transforms
|
||||
@ -59,23 +59,11 @@ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
|
||||
fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
|
||||
elif p < 0:
|
||||
fbank = fbank[:, 0:target_length]
|
||||
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1
|
||||
# channel image
|
||||
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1 channel image
|
||||
fbank = fbank.unsqueeze(0)
|
||||
return fbank
|
||||
|
||||
|
||||
def get_clip_timepoints(clip_sampler, duration):
|
||||
# Read out all clips in this video
|
||||
all_clips_timepoints = []
|
||||
is_last_clip = False
|
||||
end = 0.0
|
||||
while not is_last_clip:
|
||||
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
|
||||
all_clips_timepoints.append((start, end))
|
||||
return all_clips_timepoints
|
||||
|
||||
|
||||
def load_and_transform_vision_data(image_paths, device):
|
||||
if image_paths is None:
|
||||
return None
|
||||
@ -137,14 +125,14 @@ def load_and_transform_audio_data(
|
||||
waveform = torchaudio.functional.resample(
|
||||
waveform, orig_freq=sr, new_freq=sample_rate
|
||||
)
|
||||
all_clips_timepoints = get_clip_timepoints(
|
||||
all_clips_timepoints = get_constant_clip_timepoints(
|
||||
clip_sampler, waveform.size(1) / sample_rate
|
||||
)
|
||||
all_clips = []
|
||||
for clip_timepoints in all_clips_timepoints:
|
||||
waveform_clip = waveform[
|
||||
:,
|
||||
int(clip_timepoints[0] * sample_rate) : int(
|
||||
int(clip_timepoints[0] * sample_rate): int(
|
||||
clip_timepoints[1] * sample_rate
|
||||
),
|
||||
]
|
||||
@ -162,7 +150,8 @@ def load_and_transform_audio_data(
|
||||
return torch.stack(audio_outputs, dim=0)
|
||||
|
||||
|
||||
def get_clip_timepoints(clip_sampler, duration):
|
||||
def get_constant_clip_timepoints(clip_sampler, duration):
|
||||
assert isinstance(clip_sampler, ConstantClipsPerVideoSampler), "Incompatible Type of Sampler!"
|
||||
# Read out all clips in this video
|
||||
all_clips_timepoints = []
|
||||
is_last_clip = False
|
||||
@ -173,6 +162,13 @@ def get_clip_timepoints(clip_sampler, duration):
|
||||
return all_clips_timepoints
|
||||
|
||||
|
||||
def get_random_clip_timepoints(clip_sampler, duration):
|
||||
assert isinstance(clip_sampler, RandomMultiClipSampler), "Incompatible Type of Sampler!"
|
||||
starts, ends, _, _, _ = clip_sampler(0.0, duration, annotation=None)
|
||||
all_clips_timepoints = sorted(list(zip(starts, ends)), key=lambda x: x[0])
|
||||
return all_clips_timepoints
|
||||
|
||||
|
||||
def crop_boxes(boxes, x_offset, y_offset):
|
||||
"""
|
||||
Perform crop on the bounding boxes given the offsets.
|
||||
@ -244,7 +240,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
|
||||
x_offset = 0
|
||||
elif spatial_idx == 2:
|
||||
x_offset = width - size
|
||||
cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
|
||||
cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
|
||||
cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
|
||||
if ndim == 3:
|
||||
cropped = cropped.squeeze(0)
|
||||
@ -328,7 +324,7 @@ def load_and_transform_video_data(
|
||||
**{"sample_rate": sample_rate},
|
||||
)
|
||||
|
||||
all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
|
||||
all_clips_timepoints = get_constant_clip_timepoints(clip_sampler, video.duration)
|
||||
|
||||
all_video = []
|
||||
for clip_timepoints in all_clips_timepoints:
|
||||
|
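
The two clip helpers in this file differ in how the clip windows are produced: the constant version asks the sampler for one clip at a time until `is_last_clip` is set, while the random version receives all clips in a single call and sorts them by start time. The sketch below illustrates that difference with the standard library only. Both functions are hypothetical stand-ins and do not reproduce the pytorchvideo samplers.

```python
import random


def constant_clip_timepoints(clip_duration, clips_per_video, duration):
    """Evenly spaced (start, end) windows, a stand-in for a constant-clips sampler."""
    max_start = max(duration - clip_duration, 0.0)
    if clips_per_video == 1:
        starts = [0.0]
    else:
        step = max_start / (clips_per_video - 1)
        starts = [i * step for i in range(clips_per_video)]
    return [(s, s + clip_duration) for s in starts]


def random_clip_timepoints(clip_duration, num_clips, duration, rng=None):
    """Randomly placed windows, sorted by start time as in the random helper."""
    rng = rng or random.Random()
    max_start = max(duration - clip_duration, 0.0)
    starts = [rng.uniform(0.0, max_start) for _ in range(num_clips)]
    ends = [s + clip_duration for s in starts]
    return sorted(zip(starts, ends), key=lambda x: x[0])
```

Sorting keeps the random clips in temporal order, so downstream code that stacks the clips sees them in the same order as the constant sampler would produce.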
@ -50,46 +50,59 @@ ModalityType = SimpleNamespace(
|
||||
class ImageBindJoiner(nn.Module):
|
||||
def __init__(self,
|
||||
vision_query_token_num: int,
|
||||
audio_query_token_num: int,
|
||||
vision_qformer_frozen: bool = False,
|
||||
vision_qformer_model: str = "", # The url or path of pre-trained vision Q-Former model
|
||||
vision_pre_dims: List[int] = (), # Projection before Q-Former
|
||||
vision_post_dims: List[int] = (768, 768) # Projection after Q-Former
|
||||
vision_post_dims: List[int] = (1280, 768), # Projection after Q-Former
|
||||
audio_pre_dims: List[int] = (),
|
||||
audio_post_dims: List[int] = (768, 768)
|
||||
):
|
||||
super().__init__()
|
||||
assert not (vision_qformer_frozen and vision_qformer_model == "")
|
||||
self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims)
|
||||
self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims, audio_pre_dims)
|
||||
self.modality_qformers = self._create_modality_qformers(vision_query_token_num,
|
||||
vision_qformer_frozen,
|
||||
vision_qformer_model)
|
||||
self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims)
|
||||
vision_qformer_model,
|
||||
audio_query_token_num)
|
||||
self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims, audio_post_dims)
|
||||
|
||||
def _create_modality_pre_projectors(self,
|
||||
vision_pre_dims
|
||||
vision_pre_dims,
|
||||
audio_pre_dims
|
||||
):
|
||||
modality_pre_projectors = {
|
||||
ModalityType.VISION: create_projectors(vision_pre_dims)
|
||||
ModalityType.VISION: create_projectors(vision_pre_dims),
|
||||
ModalityType.AUDIO: create_projectors(audio_pre_dims)
|
||||
}
|
||||
return modality_pre_projectors
|
||||
|
||||
def _create_modality_qformers(self,
|
||||
vision_query_token_num,
|
||||
vision_qformer_frozen,
|
||||
vision_qformer_model
|
||||
vision_qformer_model,
|
||||
audio_query_token_num
|
||||
):
|
||||
vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
|
||||
freeze_qformer=vision_qformer_frozen,
|
||||
encoder_width=1280, # TODO: fix hard-coding
|
||||
q_former_model=vision_qformer_model)
|
||||
audio_qformer = SequenceGenericQFormer(num_query_token=audio_query_token_num,
|
||||
freeze_qformer=False,
|
||||
encoder_width=768)
|
||||
modality_qformers = {
|
||||
ModalityType.VISION: vision_qformer
|
||||
ModalityType.VISION: vision_qformer,
|
||||
ModalityType.AUDIO: audio_qformer
|
||||
}
|
||||
|
||||
return nn.ModuleDict(modality_qformers)
|
||||
|
||||
def _create_modality_post_projectors(self, vision_post_dims):
|
||||
def _create_modality_post_projectors(self, vision_post_dims, audio_post_dims):
|
||||
vision_projector = create_projectors(vision_post_dims)
|
||||
audio_projector = create_projectors(audio_post_dims)
|
||||
modality_projectors = {
|
||||
ModalityType.VISION: vision_projector
|
||||
ModalityType.VISION: vision_projector,
|
||||
ModalityType.AUDIO: audio_projector
|
||||
}
|
||||
|
||||
return nn.ModuleDict(modality_projectors)
|
||||
@ -97,7 +110,6 @@ class ImageBindJoiner(nn.Module):
|
||||
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
outputs = {}
|
||||
for modality_key, modality_value in inputs.items():
|
||||
# assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported."
|
||||
if modality_value is not None:
|
||||
modality_value = self.modality_pre_projectors[modality_key](modality_value)
|
||||
modality_value = self.modality_qformers[modality_key](modality_value)
|
||||
@ -544,7 +556,7 @@ class ImageBindModel(nn.Module):
|
||||
|
||||
# NOTE: The reduction operation has been modified.
|
||||
if reduce_list:
|
||||
modality_value = modality_value.reshape(B, S, *modality_value[2:])
|
||||
modality_value = modality_value.reshape(B, S, *modality_value.shape[1:])
|
||||
modality_value = modality_value.mean(dim=1)
|
||||
|
||||
outputs[modality_key] = modality_value
|
||||
|
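
The corrected reduction in the last hunk regroups per-clip outputs from a flat `(B * S, ...)` layout into `(B, S, ...)` and then averages over the clip axis. The sketch below shows the same grouping with nested lists standing in for tensors. `mean_over_clips` is a hypothetical helper and assumes each sample's clips are stored contiguously, which is what a row-major reshape implies.

```python
def mean_over_clips(flat_features, batch_size, clips_per_sample):
    """Group B*S feature vectors into B groups of S and average each group."""
    assert len(flat_features) == batch_size * clips_per_sample
    reduced = []
    for b in range(batch_size):
        # Clips of sample b occupy rows b*S .. (b+1)*S - 1 of the flat layout.
        group = flat_features[b * clips_per_sample:(b + 1) * clips_per_sample]
        dim = len(group[0])
        reduced.append([sum(vec[d] for vec in group) / clips_per_sample for d in range(dim)])
    return reduced
```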
@ -32,10 +32,12 @@ class Registry:
|
||||
"""
|
||||
|
||||
def wrap(builder_cls):
|
||||
# TODO: merge them or split builders by modality
|
||||
from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
|
||||
from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
|
||||
|
||||
assert issubclass(
|
||||
builder_cls, ImageBaseDatasetBuilder
|
||||
builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder)
|
||||
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
|
||||
builder_cls
|
||||
)
|
||||
|
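
The change above relies on `issubclass` accepting a tuple of classes, so a builder passes the check if it inherits from either base. A minimal sketch of that registration pattern follows. `MiniRegistry`, `ImageBaseBuilder` and `AudioBaseBuilder` are hypothetical stand-ins, not the repository's classes.

```python
class ImageBaseBuilder:
    """Hypothetical stand-in for an image dataset builder base class."""


class AudioBaseBuilder:
    """Hypothetical stand-in for an audio dataset builder base class."""


class MiniRegistry:
    """Toy registry that only tracks dataset builders."""

    builders = {}

    @classmethod
    def register_builder(cls, name):
        def wrap(builder_cls):
            # issubclass accepts a tuple, so either base class passes the check.
            assert issubclass(builder_cls, (ImageBaseBuilder, AudioBaseBuilder)), (
                "builder must inherit an image or audio base builder, found {}".format(builder_cls)
            )
            if name in cls.builders:
                raise KeyError("Name '{}' already registered".format(name))
            cls.builders[name] = builder_cls
            return builder_cls

        return wrap


@MiniRegistry.register_builder("bbc")
class BBCBuilder(AudioBaseBuilder):
    pass
```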
5
minigpt4/configs/datasets/audioset/defaults.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
audioset:
|
||||
data_type: audio
|
||||
build_info:
|
||||
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/AudioSet_SL/AudioSet_SL{00..54}.tar
|
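
The `storage` value above ends in a `{00..54}` range, a brace pattern that shard-based loaders commonly expand into one path per tar file. The sketch below expands a single numeric range with the standard library and preserves zero padding. `expand_shard_range` is a hypothetical helper, not the loader's own implementation, and it handles only the first range in a pattern.

```python
import re

_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_shard_range(pattern):
    """Expand the first {start..end} range in a shard pattern, keeping zero padding."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]
    start, end = match.group(1), match.group(2)
    width = len(start)
    return [
        pattern[:match.start()] + str(i).zfill(width) + pattern[match.end():]
        for i in range(int(start), int(end) + 1)
    ]
```

For example, `expand_shard_range("AudioSet_SL{00..54}.tar")` yields 55 names, from `AudioSet_SL00.tar` to `AudioSet_SL54.tar`, while a pattern without a range is returned unchanged.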
5
minigpt4/configs/datasets/bbc/defaults.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
bbc:
|
||||
data_type: audio
|
||||
build_info:
|
||||
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/BBC_Sound_Effects/BBC_Sound_Effects{000000..000062}.tar
|
5
minigpt4/configs/datasets/freesound/defaults.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
freesound:
|
||||
data_type: audio
|
||||
build_info:
|
||||
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/FreeSound/FreeSound{000000..000524}.tar
|
5
minigpt4/configs/datasets/soundbible/defaults.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
soundbible:
|
||||
data_type: audio
|
||||
build_info:
|
||||
storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/SoundBible/SoundBible0.tar
|
@ -11,12 +11,23 @@ from minigpt4.datasets.builders.image_text_pair_builder import (
|
||||
LaionBuilderImage,
|
||||
CCSBUAlignBuilderImage
|
||||
)
|
||||
from minigpt4.datasets.builders.audio_text_pair_builder import (
|
||||
BBCBuilder,
|
||||
AudioSetBuilder,
|
||||
SoundBibleBuilder,
|
||||
FreeSoundBuilder
|
||||
)
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
__all__ = [
|
||||
"CCSBUBuilderImage",
|
||||
"LaionBuilderImage",
|
||||
"CCSBUAlignBuilderImage"
|
||||
"CCSBUAlignBuilderImage",
|
||||
# Audio builders
|
||||
"BBCBuilder",
|
||||
"AudioSetBuilder",
|
||||
"SoundBibleBuilder",
|
||||
"FreeSoundBuilder",
|
||||
]
|
||||
|
||||
|
||||
|
@ -31,6 +31,51 @@ class AudioBaseDatasetBuilder:
|
||||
|
||||
self.data_type = self.config.data_type
|
||||
|
||||
self.audio_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
|
||||
self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
|
||||
|
||||
def build_datasets(self):
|
||||
# download, split, etc...
|
||||
# only called on 1 GPU/TPU in distributed
|
||||
|
||||
if is_main_process():
|
||||
self._download_data()
|
||||
|
||||
if is_dist_avail_and_initialized():
|
||||
dist.barrier()
|
||||
|
||||
# at this point, all the annotations and audio files should be downloaded to the specified locations.
|
||||
logging.info("Building datasets...")
|
||||
datasets = self.build() # dataset['train'/'val'/'test']
|
||||
|
||||
return datasets
|
||||
|
||||
def build_processors(self):
|
||||
aud_proc_cfg = self.config.get("audio_processor")
|
||||
txt_proc_cfg = self.config.get("text_processor")
|
||||
|
||||
if aud_proc_cfg is not None:
|
||||
aud_train_cfg = aud_proc_cfg.get("train")
|
||||
aud_eval_cfg = aud_proc_cfg.get("eval")
|
||||
|
||||
self.audio_processors["train"] = self._build_proc_from_cfg(aud_train_cfg)
|
||||
self.audio_processors["eval"] = self._build_proc_from_cfg(aud_eval_cfg)
|
||||
|
||||
if txt_proc_cfg is not None:
|
||||
txt_train_cfg = txt_proc_cfg.get("train")
|
||||
txt_eval_cfg = txt_proc_cfg.get("eval")
|
||||
|
||||
self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
|
||||
self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
|
||||
|
||||
@staticmethod
|
||||
def _build_proc_from_cfg(cfg):
|
||||
return (
|
||||
registry.get_processor_class(cfg.name).from_config(cfg)
|
||||
if cfg is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def default_config_path(cls, type="default"):
|
||||
return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
|
||||
@ -95,5 +140,3 @@ class AudioBaseDatasetBuilder:
|
||||
filename = os.path.basename(storage_path)
|
||||
|
||||
download_url(url=url_or_filename, root=dirname, filename=filename)
|
||||
|
||||
|
||||
|
56
minigpt4/datasets/builders/audio_text_pair_builder.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
|
||||
from minigpt4.datasets.datasets.audio_caption import GenericAudioDataset
|
||||
|
||||
|
||||
|
||||
class GenericAudioBuilder(AudioBaseDatasetBuilder):
|
||||
train_dataset_cls = GenericAudioDataset
|
||||
|
||||
def _download_ann(self):
|
||||
pass
|
||||
|
||||
def _download_aud(self):
|
||||
pass
|
||||
|
||||
def build(self):
|
||||
self.build_processors()
|
||||
|
||||
build_info = self.config.build_info
|
||||
|
||||
datasets = dict()
|
||||
split = "train"
|
||||
|
||||
# create datasets
|
||||
dataset_cls = self.train_dataset_cls
|
||||
datasets[split] = dataset_cls(
|
||||
audio_processor=self.audio_processors[split],
|
||||
text_processor=self.text_processors[split],
|
||||
location=build_info.storage,
|
||||
).inner_dataset
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
@registry.register_builder("bbc")
|
||||
class BBCBuilder(GenericAudioBuilder):
|
||||
DATASET_CONFIG_DICT = {"default": "configs/datasets/bbc/defaults.yaml"}
|
||||
|
||||
|
||||
@registry.register_builder("audioset")
|
||||
class AudioSetBuilder(GenericAudioBuilder):
|
||||
DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults.yaml"}
|
||||
|
||||
|
||||
@registry.register_builder("soundbible")
|
||||
class SoundBibleBuilder(GenericAudioBuilder):
|
||||
DATASET_CONFIG_DICT = {"default": "configs/datasets/soundbible/defaults.yaml"}
|
||||
|
||||
|
||||
@registry.register_builder("freesound")
|
||||
class FreeSoundBuilder(GenericAudioBuilder):
|
||||
DATASET_CONFIG_DICT = {"default": "configs/datasets/freesound/defaults.yaml"}
|
@ -0,0 +1 @@
|
||||
from minigpt4.datasets.datasets.audio_caption.audio_caption_datasets import GenericAudioDataset
|
@ -6,21 +6,22 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset
|
||||
|
||||
|
||||
class GenericAudioDataset(BaseDataset):
|
||||
def __init__(self, vision_processor, text_processor, location):
|
||||
super().__init__(x_processor=vision_processor, text_processor=text_processor)
|
||||
def __init__(self, audio_processor, text_processor, location):
|
||||
super().__init__(x_processor=audio_processor, text_processor=text_processor)
|
||||
|
||||
self.inner_dataset = wds.DataPipeline(
|
||||
wds.ResampledShards(location),
|
||||
wds.tarfile_to_samples(handler=wds.warn_and_continue),
|
||||
wds.shuffle(1000, handler=wds.warn_and_continue),
|
||||
wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
|
||||
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
|
||||
wds.to_tuple("flac", "json", handler=wds.warn_and_continue),
|
||||
wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
|
||||
wds.map(self.to_dict, handler=wds.warn_and_continue),
|
||||
)
|
||||
|
||||
def to_dict(self, sample):
|
||||
return {
|
||||
"image": sample[0],
|
||||
"audio": sample[0],
|
||||
# [clips_per_video, channel, mel_bins, time_steps]
|
||||
"text_input": self.text_processor(sample[1]["caption"]),
|
||||
}
|
||||
|
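
The pipeline in `GenericAudioDataset` chains three per-sample stages: select the `flac` and `json` fields, run the audio processor on the first element, and build the output dictionary. The generator-based sketch below mimics that stage order with plain Python. It is not webdataset itself: the pass-through of tuple elements that have no function is an assumption about `map_tuple`, `fake_audio_processor` is hypothetical, and `str.strip` stands in for the text processor.

```python
def to_tuple(samples, *keys):
    """Pick the given keys out of each sample dict, in order."""
    for sample in samples:
        yield tuple(sample[k] for k in keys)


def map_tuple(samples, *funcs):
    """Apply funcs[i] to element i; elements without a function pass through."""
    for sample in samples:
        items = list(sample)
        for i, func in enumerate(funcs):
            items[i] = func(items[i])
        yield tuple(items)


def to_dict(sample):
    """Build the model input dict from an (audio, metadata) pair."""
    audio, meta = sample
    return {"audio": audio, "text_input": meta["caption"].strip()}


def fake_audio_processor(waveform):
    # Hypothetical processor: the real one would produce mel-spectrogram clips.
    return [abs(x) for x in waveform]


raw = [{"flac": [-0.5, 0.25], "json": {"caption": " a dog barks "}}]
batch = [to_dict(s) for s in map_tuple(to_tuple(raw, "flac", "json"), fake_audio_processor)]
```

Only the audio element is transformed by the processor in this sketch. The caption is handled later, inside `to_dict`.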
@ -21,7 +21,7 @@ class CCSBUDataset(BaseDataset):
|
||||
|
||||
def to_dict(self, sample):
|
||||
return {
|
||||
"image": sample[0],
|
||||
"vision": sample[0],
|
||||
"text_input": self.text_processor(sample[1]["caption"]),
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
|
||||
caption = ann["caption"]
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"vision": image,
|
||||
"text_input": caption,
|
||||
"image_id": self.img_ids[ann["image_id"]],
|
||||
}
|
||||
@ -49,7 +49,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
|
||||
|
||||
class CCDataset(BaseDataset):
|
||||
def __init__(self, vis_processor, text_processor, location):
|
||||
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
||||
super().__init__(x_processor=vis_processor, text_processor=text_processor)
|
||||
|
||||
self.inner_dataset = wds.DataPipeline(
|
||||
wds.ResampledShards(location),
|
||||
@ -57,12 +57,12 @@ class CCDataset(BaseDataset):
|
||||
wds.shuffle(1000, handler=wds.warn_and_continue),
|
||||
wds.decode("pilrgb", handler=wds.warn_and_continue),
|
||||
wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),
|
||||
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
|
||||
wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
|
||||
wds.map(self.to_dict, handler=wds.warn_and_continue),
|
||||
)
|
||||
|
||||
def to_dict(self, sample):
|
||||
return {
|
||||
"image": sample[0],
|
||||
"vision": sample[0],
|
||||
"text_input": sample[1],
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
|
||||
caption = self.text_processor(ann["caption"])
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"vision": image,
|
||||
"text_input": caption,
|
||||
"image_id": self.img_ids[ann["image_id"]],
|
||||
}
|
||||
@ -67,7 +67,7 @@ class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
|
||||
image = self.x_processor(image)
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"vision": image,
|
||||
"image_id": ann["image_id"],
|
||||
"instance_id": ann["instance_id"],
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ class LaionDataset(BaseDataset):
|
||||
|
||||
def to_dict(self, sample):
|
||||
return {
|
||||
"image": sample[0],
|
||||
"vision": sample[0],
|
||||
"text_input": self.text_processor(sample[1]["caption"]),
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ class __ImageDisplMixin:
|
||||
{
|
||||
"file": ann["image"],
|
||||
"caption": ann["caption"],
|
||||
"image": sample["image"],
|
||||
"vision": sample["vision"],
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
import random
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import re
|
||||
from torch import Tensor
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
@ -12,6 +13,36 @@ from minigpt4.models.blip2 import BaseModel
|
||||
from minigpt4.models.modeling_llama import LlamaForCausalLM
|
||||
|
||||
|
||||
def filter_prompt(input_embeds: Dict[str, Tensor], prompt_list: List[str]) -> List[str]:
|
||||
if not prompt_list:
|
||||
return prompt_list
|
||||
input_modal_set = set([k.title() for k in input_embeds if input_embeds[k] is not None])
|
||||
prompt_modal_sets = [set(re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)) for prompt in prompt_list]
|
||||
results = [prompt_list[i] for i, prompt_modal_set in enumerate(prompt_modal_sets) if
|
||||
prompt_modal_set == input_modal_set]
|
||||
return results
|
||||
|
||||
|
||||
def arrange_modalities(input_embeds: Dict[str, Tensor], prompt: str) -> List[Tensor]:
|
||||
prompt_modalities = re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)
|
||||
return [input_embeds[modality.lower()] for modality in prompt_modalities]
|
||||
|
||||
|
||||
def concat_all_embeddings(input_embeds: Dict[str, Tensor], dim: int) -> Tensor:
|
||||
embeds = [input_embeds[key] for key in input_embeds if input_embeds[key] is not None]
|
||||
return torch.cat(embeds, dim=dim)
|
||||
|
||||
|
||||
def filter_modalities(inputs):
|
||||
filtered_inputs = {}
|
||||
|
||||
for k in ModalityType.__dict__.values():
|
||||
if k in inputs:
|
||||
filtered_inputs[k] = inputs[k]
|
||||
|
||||
return filtered_inputs
|
||||
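Reviewer sketch: `filter_modalities` keeps only the batch entries whose keys are modality names, so fields such as captions never reach the encoder. The example assumes `ModalityType` behaves like a simple namespace of string constants; the stand-in defined here is hypothetical and lists only two modalities.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the ModalityType used in the diff.
# The real object may define more modalities than these two.
ModalityType = SimpleNamespace(VISION="vision", AUDIO="audio")


def filter_modalities(inputs):
    """Return a new dict holding only the entries keyed by a known modality name."""
    return {name: inputs[name] for name in vars(ModalityType).values() if name in inputs}


batch = {
    "vision": "image-tensor",   # placeholder for a real tensor
    "audio": "audio-tensor",    # placeholder for a real tensor
    "text_input": "a caption",  # not a modality, must be dropped
}
modality_inputs = filter_modalities(batch)
```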
|
||||
|
||||
@registry.register_model("bind_gpt4")
|
||||
class BindGPT4(BaseModel):
|
||||
"""
|
||||
@ -61,7 +92,9 @@ class BindGPT4(BaseModel):
|
||||
print('Loading Q-Former and Adapter/Projector')
|
||||
self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
|
||||
vision_qformer_frozen=freeze_qformer,
|
||||
vision_post_dims=[768, self.llama_model.config.hidden_size]
|
||||
vision_post_dims=[768, self.llama_model.config.hidden_size],
|
||||
audio_query_token_num=num_query_token,
|
||||
audio_post_dims=[768, self.llama_model.config.hidden_size]
|
||||
# vision_qformer_model=q_former_model,
|
||||
# vision_pre_dims=(1280, 1408)
|
||||
)
|
||||
@ -84,42 +117,37 @@ class BindGPT4(BaseModel):
|
||||
def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
imagebind_outputs = self.multimodal_encoder(inputs)
|
||||
llama_inputs = self.multimodal_joiner(imagebind_outputs)
|
||||
# NOTE: only accept image here
|
||||
return llama_inputs
|
||||
|
||||
def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
|
||||
# TODO: Accept More Modalities.
|
||||
input_embeds = inputs[modality_name]
|
||||
def prompt_wrap(self, inputs: Dict[str, Tensor], prompt: str) -> Tuple[Tensor, Tensor]:
|
||||
if not prompt:
|
||||
input_embeds = concat_all_embeddings(inputs, dim=1)
|
||||
attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
|
||||
if prompt:
|
||||
batch_size = input_embeds.shape[0]
|
||||
p_before, p_after = prompt.split('<ModalityHere>')
|
||||
p_before_tokens = self.llama_tokenizer(
|
||||
p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
|
||||
p_after_tokens = self.llama_tokenizer(
|
||||
p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
|
||||
p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
wrapped_input_embeds = torch.cat([p_before_embeds, inputs, p_after_embeds], dim=1)
|
||||
wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1])
|
||||
return wrapped_input_embeds, wrapped_atts_input
|
||||
else:
|
||||
return input_embeds, attns_input
|
||||
input_embeds_list = arrange_modalities(inputs, prompt)
|
||||
batch_size = input_embeds_list[0].shape[0]
|
||||
prompt_slices = prompt.split('<ModalityHere>')
|
||||
prompt_tokens = [self.llama_tokenizer(prompt_slice, return_tensors="pt", add_special_tokens=False)
|
||||
.to(input_embeds_list[0].device) for prompt_slice in prompt_slices]
|
||||
prompt_embeds = [self.llama_model.model.embed_tokens(prompt_token.input_ids).expand(batch_size, -1, -1)
|
||||
for prompt_token in prompt_tokens]
|
||||
result_embeds = [emb for pair in zip(prompt_embeds[:-1], input_embeds_list)
|
||||
for emb in pair] + [prompt_embeds[-1]]
|
||||
wrapped_input_embeds = torch.cat(result_embeds, dim=1)
|
||||
wrapped_atts_input = torch.ones(wrapped_input_embeds.size()[:-1],
|
||||
dtype=torch.long).to(wrapped_input_embeds.device)
|
||||
return wrapped_input_embeds, wrapped_atts_input
|
||||
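Reviewer sketch: the new `prompt_wrap` splits the prompt at every `<ModalityHere>` marker and then interleaves the text pieces with the modality embeddings. The interleaving step can be shown with plain strings standing in for embedding tensors. The prompt and the bracketed placeholders are hypothetical.

```python
from typing import List


def interleave(prompt_slices: List[str], modality_items: List[str]) -> List[str]:
    """Alternate text slices and modality items, ending with the last text slice."""
    if len(prompt_slices) != len(modality_items) + 1:
        raise ValueError("expected exactly one more prompt slice than modality items")
    merged = [piece for pair in zip(prompt_slices[:-1], modality_items) for piece in pair]
    merged.append(prompt_slices[-1])
    return merged


# Hypothetical two-modality prompt.
prompt = "<Vision><ModalityHere></Vision> <Audio><ModalityHere></Audio> Describe both."
slices = prompt.split("<ModalityHere>")
merged = interleave(slices, ["[VISION_EMBEDS]", "[AUDIO_EMBEDS]"])
```

In the model the merged list holds tensors and is concatenated along the sequence dimension. The length check is an addition of this sketch; the code in the diff relies on `zip` and would silently drop items if the counts disagreed.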
|
||||
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""
|
||||
TODO: More Modalities.
|
||||
Only accept image inputs here.
|
||||
Other modalities will conflict with the pre-defined prompt and wrapping strategy.
|
||||
"""
|
||||
bind_inputs = {ModalityType.VISION: inputs['image']}
|
||||
embeds = self.encode_inputs(bind_inputs)
|
||||
# assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
|
||||
if self.prompt_list:
|
||||
prompt = random.choice(self.prompt_list)
|
||||
# filter `inputs` as it may contain information other than modalities
|
||||
modality_inputs = filter_modalities(inputs)
|
||||
embeds = self.encode_inputs(modality_inputs)
|
||||
filtered_prompts = filter_prompt(embeds, self.prompt_list)
|
||||
if filtered_prompts:
|
||||
prompt = random.choice(filtered_prompts)
|
||||
else:
|
||||
prompt = None
|
||||
img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
|
||||
input_embs, input_atts = self.prompt_wrap(embeds, prompt)
|
||||
|
||||
# NOTE: No modifications from the next line to the end. Except for the autocast part.
|
||||
|
||||
@ -134,28 +162,28 @@ class BindGPT4(BaseModel):
|
||||
truncation=True,
|
||||
max_length=self.max_txt_len,
|
||||
add_special_tokens=False
|
||||
).to(img_embeds.device)
|
||||
).to(input_embs.device)
|
||||
|
||||
targets = to_regress_tokens.input_ids.masked_fill(
|
||||
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
|
||||
)
|
||||
|
||||
empty_targets = (
|
||||
torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
|
||||
dtype=torch.long).to(img_embeds.device).fill_(-100) # plus one for bos
|
||||
torch.ones([input_atts.shape[0], input_atts.shape[1] + 1],
|
||||
dtype=torch.long).to(input_embs.device).fill_(-100) # plus one for bos
|
||||
)
|
||||
targets = torch.cat([empty_targets, targets], dim=1)
|
||||
|
||||
batch_size = img_embeds.shape[0]
|
||||
batch_size = input_embs.shape[0]
|
||||
bos = torch.ones([batch_size, 1],
|
||||
dtype=to_regress_tokens.input_ids.dtype,
|
||||
device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
|
||||
bos_embeds = self.llama_model.model.embed_tokens(bos)
|
||||
atts_bos = atts_img[:, :1]
|
||||
atts_bos = input_atts[:, :1]
|
||||
|
||||
to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
|
||||
inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
|
||||
attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
|
||||
inputs_embeds = torch.cat([bos_embeds, input_embs, to_regress_embeds], dim=1)
|
||||
attention_mask = torch.cat([atts_bos, input_atts, to_regress_tokens.attention_mask], dim=1)
|
||||
|
||||
outputs = self.llama_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
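Reviewer sketch: the target construction in `forward` masks every position that should not contribute to the loss with -100, which is the default `ignore_index` of PyTorch's cross-entropy loss. The list-based example below shows the resulting layout for one sequence; the function name and the token ids are hypothetical.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_targets(prompt_len: int, text_ids: List[int], pad_id: int) -> List[int]:
    """Mask BOS + wrapped prompt positions and padding; keep caption token ids."""
    text_targets = [IGNORE_INDEX if tok == pad_id else tok for tok in text_ids]
    # "+ 1" accounts for the BOS token prepended before the wrapped prompt.
    return [IGNORE_INDEX] * (prompt_len + 1) + text_targets


# Three prompt/modality positions, caption ids [5, 6], one pad token (id 0).
targets = build_targets(prompt_len=3, text_ids=[5, 6, 0], pad_id=0)
```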
|
@ -144,7 +144,7 @@ def compute_sim_matrix(model, data_loader, **kwargs):
|
||||
vit_feats = []
|
||||
image_embeds = []
|
||||
for samples in data_loader:
|
||||
image = samples["image"]
|
||||
image = samples["vision"]
|
||||
|
||||
image = image.to(model.device)
|
||||
image_feat, vit_feat = model.forward_image(image)
|
||||
|
@ -164,7 +164,7 @@ class MiniGPT4(Blip2Base):
|
||||
return img_embeds, atts_img
|
||||
|
||||
def forward(self, samples):
|
||||
image = samples["image"]
|
||||
image = samples["vision"]
|
||||
img_embeds, atts_img = self.encode_img(image)
|
||||
if hasattr(samples, 'question_split'): # VQA dataset
|
||||
print('VQA Batch')
|
||||
|
@ -11,11 +11,15 @@ from minigpt4.processors.blip_processors import (
|
||||
Blip2ImageEvalProcessor,
|
||||
BlipCaptionProcessor,
|
||||
)
|
||||
from minigpt4.processors.imagebind_processor import (
|
||||
from minigpt4.processors.imagebind_vision_processor import (
|
||||
ImageBindCaptionProcessor,
|
||||
ImageBindVisionTrainProcessor,
|
||||
ImageBindVisionEvalProcessor
|
||||
)
|
||||
from minigpt4.processors.imagebind_audio_processor import (
|
||||
ImageBindAudioTrainProcessor,
|
||||
ImageBindAudioEvalProcessor,
|
||||
)
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
@ -26,7 +30,9 @@ __all__ = [
|
||||
"BlipCaptionProcessor",
|
||||
"ImageBindCaptionProcessor",
|
||||
"ImageBindVisionTrainProcessor",
|
||||
"ImageBindVisionEvalProcessor"
|
||||
"ImageBindVisionEvalProcessor",
|
||||
"ImageBindAudioTrainProcessor",
|
||||
"ImageBindAudioEvalProcessor",
|
||||
]
|
||||
|
||||
|
||||
|
190
minigpt4/processors/audio_augment.py
Normal file
@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Author : Xinhao Mei @CVSSP, University of Surrey
|
||||
# @E-mail : x.mei@surrey.ac.uk
|
||||
|
||||
"""
|
||||
Implementation of SpecAugment++,
|
||||
Adapted from Qiuqiang Kong's torchlibrosa:
|
||||
https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DropStripes:
|
||||
|
||||
def __init__(self, dim, drop_width, stripes_num):
|
||||
""" Drop stripes.
|
||||
args:
|
||||
dim: int, dimension along which to drop
|
||||
drop_width: int, maximum width of stripes to drop
|
||||
stripes_num: int, how many stripes to drop
|
||||
"""
|
||||
super(DropStripes, self).__init__()
|
||||
|
||||
assert dim in [2, 3] # dim 2: time; dim 3: frequency
|
||||
|
||||
self.dim = dim
|
||||
self.drop_width = drop_width
|
||||
self.stripes_num = stripes_num
|
||||
|
||||
def __call__(self, input):
|
||||
"""input: (batch_size, channels, time_steps, freq_bins)"""
|
||||
|
||||
assert input.ndimension() == 4
|
||||
batch_size = input.shape[0]
|
||||
total_width = input.shape[self.dim]
|
||||
|
||||
for n in range(batch_size):
|
||||
self.transform_slice(input[n], total_width)
|
||||
|
||||
return input
|
||||
|
||||
def transform_slice(self, e, total_width):
|
||||
""" e: (channels, time_steps, freq_bins)"""
|
||||
|
||||
for _ in range(self.stripes_num):
|
||||
distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
|
||||
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
|
||||
|
||||
if self.dim == 2:
|
||||
e[:, bgn: bgn + distance, :] = 0
|
||||
elif self.dim == 3:
|
||||
e[:, :, bgn: bgn + distance] = 0
|
||||
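Reviewer sketch: `DropStripes` zeroes a few randomly placed stripes along one axis. The version below applies the same sampling rule to a plain list-of-lists spectrogram (time steps by frequency bins) using the standard `random` module instead of `torch.randint`. The function name and the toy input are hypothetical.

```python
import random
from typing import List


def drop_time_stripes(spec: List[List[float]], drop_width: int, stripes_num: int,
                      rng: random.Random) -> List[List[float]]:
    """Zero `stripes_num` stripes of time steps in place; each stripe is narrower than `drop_width`."""
    total_width = len(spec)
    for _ in range(stripes_num):
        # Upper bounds are exclusive, mirroring torch.randint(low, high).
        distance = rng.randrange(0, drop_width)
        bgn = rng.randrange(0, total_width - distance)
        for t in range(bgn, bgn + distance):
            spec[t] = [0.0] * len(spec[t])
    return spec


spec = [[1.0] * 4 for _ in range(20)]  # 20 time steps, 4 frequency bins
augmented = drop_time_stripes(spec, drop_width=4, stripes_num=2, rng=random.Random(0))
zeroed_steps = sum(1 for row in augmented if all(v == 0.0 for v in row))
```

Because the upper bound is exclusive, a stripe can have width 0, and the widest possible stripe is `drop_width - 1`.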
|
||||
|
||||
class MixStripes:
|
||||
|
||||
def __init__(self, dim, mix_width, stripes_num):
|
||||
""" Mix stripes
|
||||
args:
|
||||
dim: int, dimension along which to mix
|
||||
mix_width: int, maximum width of stripes to mix
|
||||
stripes_num: int, how many stripes to mix
|
||||
"""
|
||||
|
||||
super(MixStripes, self).__init__()
|
||||
|
||||
assert dim in [2, 3]
|
||||
|
||||
self.dim = dim
|
||||
self.mix_width = mix_width
|
||||
self.stripes_num = stripes_num
|
||||
|
||||
def __call__(self, input):
|
||||
"""input: (batch_size, channel, time_steps, freq_bins)"""
|
||||
|
||||
assert input.ndimension() == 4
|
||||
|
||||
batch_size = input.shape[0]
|
||||
total_width = input.shape[self.dim]
|
||||
|
||||
rand_sample = input[torch.randperm(batch_size)]
|
||||
for i in range(batch_size):
|
||||
self.transform_slice(input[i], rand_sample[i], total_width)
|
||||
return input
|
||||
|
||||
def transform_slice(self, input, random_sample, total_width):
|
||||
|
||||
for _ in range(self.stripes_num):
|
||||
distance = torch.randint(low=0, high=self.mix_width, size=(1,))[0]
|
||||
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
|
||||
|
||||
if self.dim == 2:
|
||||
input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \
|
||||
0.5 * random_sample[:, bgn: bgn + distance, :]
|
||||
elif self.dim == 3:
|
||||
input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \
|
||||
0.5 * random_sample[:, :, bgn: bgn + distance]
|
||||
|
||||
|
||||
class CutStripes:
|
||||
|
||||
def __init__(self, dim, cut_width, stripes_num):
|
||||
""" Cutting stripes with another randomly selected sample in mini-batch.
|
||||
args:
|
||||
dim: int, dimension along which to cut
|
||||
cut_width: int, maximum width of stripes to cut
|
||||
stripes_num: int, how many stripes to cut
|
||||
"""
|
||||
|
||||
super(CutStripes, self).__init__()
|
||||
|
||||
assert dim in [2, 3]
|
||||
|
||||
self.dim = dim
|
||||
self.cut_width = cut_width
|
||||
self.stripes_num = stripes_num
|
||||
|
||||
def __call__(self, input):
|
||||
"""input: (batch_size, channel, time_steps, freq_bins)"""
|
||||
|
||||
assert input.ndimension() == 4
|
||||
|
||||
batch_size = input.shape[0]
|
||||
total_width = input.shape[self.dim]
|
||||
|
||||
rand_sample = input[torch.randperm(batch_size)]
|
||||
for i in range(batch_size):
|
||||
self.transform_slice(input[i], rand_sample[i], total_width)
|
||||
return input
|
||||
|
||||
def transform_slice(self, input, random_sample, total_width):
|
||||
|
||||
for _ in range(self.stripes_num):
|
||||
distance = torch.randint(low=0, high=self.cut_width, size=(1,))[0]
|
||||
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
|
||||
|
||||
if self.dim == 2:
|
||||
input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :]
|
||||
elif self.dim == 3:
|
||||
input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance]
|
||||
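Reviewer sketch: `MixStripes` and `CutStripes` differ only in what they write into the selected stripe. Mixing averages the sample with another sample from the batch, while cutting copies the other sample over it. The one-dimensional example below uses a fixed stripe so the result is deterministic; the helper names are hypothetical, and unlike the classes above they return new lists instead of modifying the input in place.

```python
from typing import List


def mix_stripe(sample: List[float], other: List[float], bgn: int, distance: int) -> List[float]:
    """Average `sample` with `other` inside [bgn, bgn + distance)."""
    out = list(sample)
    for t in range(bgn, bgn + distance):
        out[t] = 0.5 * sample[t] + 0.5 * other[t]
    return out


def cut_stripe(sample: List[float], other: List[float], bgn: int, distance: int) -> List[float]:
    """Replace `sample` with `other` inside [bgn, bgn + distance)."""
    out = list(sample)
    out[bgn:bgn + distance] = other[bgn:bgn + distance]
    return out


sample = [1.0] * 5
other = [3.0] * 5
mixed = mix_stripe(sample, other, bgn=1, distance=2)
cut = cut_stripe(sample, other, bgn=1, distance=2)
```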
|
||||
|
||||
class SpecAugmentation:
|
||||
|
||||
def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
|
||||
mask_type='mixture'):
|
||||
"""Spec augmentation and SpecAugment++.
|
||||
[ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
|
||||
and Le, Q.V., 2019. Specaugment: A simple data augmentation method
|
||||
for automatic speech recognition. arXiv preprint arXiv:1904.08779.
|
||||
[ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
|
||||
Data Augmentation Method for Acoustic Scene Classification. arXiv
|
||||
preprint arXiv:2103.16858.
|
||||
|
||||
Args:
|
||||
time_drop_width: int
|
||||
time_stripes_num: int
|
||||
freq_drop_width: int
|
||||
freq_stripes_num: int
|
||||
mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)
|
||||
"""
|
||||
|
||||
super(SpecAugmentation, self).__init__()
|
||||
|
||||
if mask_type == 'zero_value':
|
||||
self.time_augmentator = DropStripes(dim=2, drop_width=time_drop_width,
|
||||
stripes_num=time_stripes_num)
|
||||
self.freq_augmentator = DropStripes(dim=3, drop_width=freq_drop_width,
|
||||
stripes_num=freq_stripes_num)
|
||||
elif mask_type == 'mixture':
|
||||
self.time_augmentator = MixStripes(dim=2, mix_width=time_drop_width,
|
||||
stripes_num=time_stripes_num)
|
||||
self.freq_augmentator = MixStripes(dim=3, mix_width=freq_drop_width,
|
||||
stripes_num=freq_stripes_num)
|
||||
elif mask_type == 'cutting':
|
||||
self.time_augmentator = CutStripes(dim=2, cut_width=time_drop_width,
|
||||
stripes_num=time_stripes_num)
|
||||
self.freq_augmentator = CutStripes(dim=3, cut_width=freq_drop_width,
|
||||
stripes_num=freq_stripes_num)
|
||||
else:
|
||||
raise NameError('No such mask type in SpecAugment++')
|
||||
|
||||
def __call__(self, inputs):
|
||||
# x should be in size [batch_size, channel, time_steps, freq_bins]
|
||||
x = self.time_augmentator(inputs)
|
||||
x = self.freq_augmentator(x)
|
||||
return x
|
@ -9,7 +9,7 @@ import re
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.processors.base_processor import BaseProcessor
|
||||
from minigpt4.processors.randaugment import RandomAugment
|
||||
from minigpt4.processors.vision_augment import RandomAugment
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
166
minigpt4/processors/imagebind_audio_processor.py
Normal file
@ -0,0 +1,166 @@
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from omegaconf import OmegaConf
|
||||
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
|
||||
from torch import Tensor
|
||||
|
||||
from imagebind.data.data_utils import waveform2melspec, get_constant_clip_timepoints, \
|
||||
get_random_clip_timepoints
|
||||
from minigpt4.processors.base_processor import BaseProcessor
|
||||
from torchvision import transforms
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.processors.audio_augment import SpecAugmentation
|
||||
|
||||
|
||||
class ImageBindAudioBaseProcessor(BaseProcessor):
|
||||
def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
|
||||
num_mel_bins=None, target_length=None, clip_sample_method="Random"):
|
||||
super().__init__()
|
||||
self.mean = -4.268 if mean is None else mean
|
||||
self.std = 9.138 if std is None else std
|
||||
self.target_sr = 16000 if target_sr is None else target_sr
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.target_length = target_length
|
||||
self.clip_sampler = self._construct_clip_sampler(clip_duration, clips_per_video, clip_sample_method)
|
||||
self.normalize = transforms.Normalize(self.mean, self.std)
|
||||
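Reviewer sketch: the base processor falls back to a mean of -4.268 and a standard deviation of 9.138 and passes them to `transforms.Normalize`, which computes `(x - mean) / std`. The plain-Python version below applies the same formula to a list of values; the function name is hypothetical.

```python
from typing import List

AUDIO_MEAN = -4.268  # default mean from the diff
AUDIO_STD = 9.138    # default std from the diff


def normalize(values: List[float], mean: float = AUDIO_MEAN, std: float = AUDIO_STD) -> List[float]:
    """Standardize values with the same formula as torchvision's Normalize."""
    return [(v - mean) / std for v in values]


# A value equal to the mean maps to 0; one std above the mean maps to about 1.
normalized = normalize([AUDIO_MEAN, AUDIO_MEAN + AUDIO_STD])
```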
|
||||
def _construct_clip_sampler(self, clip_duration, clips_per_video, clip_sample_method):
|
||||
if clip_duration is None or clips_per_video is None:
|
||||
return None
|
||||
if clip_sample_method == "Constant":
|
||||
return ConstantClipsPerVideoSampler(
|
||||
clip_duration=clip_duration, clips_per_video=clips_per_video
|
||||
)
|
||||
elif clip_sample_method == "Random":
|
||||
return RandomMultiClipSampler(clip_duration=clip_duration, num_clips=clips_per_video)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def waveform_resample(self, waveform: Tensor, origin_sr: int) -> Tensor:
|
||||
return torchaudio.functional.resample(
|
||||
waveform, orig_freq=origin_sr, new_freq=self.target_sr
|
||||
)
|
||||
|
||||
def clip_sample(self, waveform: Tensor) -> List[Tensor]:
|
||||
if self.clip_sampler is None:
|
||||
return [waveform]
|
||||
elif isinstance(self.clip_sampler, ConstantClipsPerVideoSampler):
|
||||
all_clips_timepoints = get_constant_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
|
||||
elif isinstance(self.clip_sampler, RandomMultiClipSampler):
|
||||
all_clips_timepoints = get_random_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
all_clips = []
|
||||
for clip_timepoints in all_clips_timepoints:
|
||||
start_pos = int(clip_timepoints[0] * self.target_sr)
|
||||
end_pos = int(clip_timepoints[1] * self.target_sr)
|
||||
waveform_clip = waveform[:, start_pos: end_pos]
|
||||
all_clips.append(waveform_clip)
|
||||
return all_clips
|
||||
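Reviewer sketch: `clip_sample` converts each clip's start and end time in seconds into sample indices by multiplying by the target sample rate and truncating to an integer. The example below isolates that conversion. The timepoints are hypothetical and only illustrate three consecutive 2-second clips; the actual values come from the clip sampler.

```python
from typing import List, Tuple


def clip_bounds(timepoints: List[Tuple[float, float]], sample_rate: int) -> List[Tuple[int, int]]:
    """Convert (start_sec, end_sec) pairs into (start_index, end_index) pairs."""
    return [(int(start * sample_rate), int(end * sample_rate)) for start, end in timepoints]


# Hypothetical timepoints for three 2-second clips at 16 kHz.
bounds = clip_bounds([(0.0, 2.0), (2.0, 4.0), (4.0, 6.0)], sample_rate=16000)
clip_lengths = [end - start for start, end in bounds]
```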
|
||||
def waveform_melspec(self, waveforms: Union[List[Tensor], Tensor]) -> List[Tensor]:
|
||||
if isinstance(waveforms, Tensor):
|
||||
return waveform2melspec(waveforms, self.target_sr, self.num_mel_bins, self.target_length)
|
||||
else:
|
||||
return [waveform2melspec(waveform, self.target_sr, self.num_mel_bins, self.target_length)
|
||||
for waveform in waveforms]
|
||||
|
||||
|
||||
@registry.register_processor("imagebind_audio_train")
|
||||
class ImageBindAudioTrainProcessor(ImageBindAudioBaseProcessor):
|
||||
def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
|
||||
clip_sample_method="Random", num_mel_bins=None, target_length=None, time_drop_width=13,
|
||||
time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2, mask_type='mixture'):
|
||||
super().__init__(mean=mean, std=std, target_sr=target_sr,
|
||||
clip_duration=clip_duration, clips_per_video=clips_per_video,
|
||||
num_mel_bins=num_mel_bins, target_length=target_length,
|
||||
clip_sample_method=clip_sample_method)
|
||||
self.spec_augment = SpecAugmentation(time_drop_width, time_stripes_num,
|
||||
freq_drop_width, freq_stripes_num, mask_type)
|
||||
|
||||
def __call__(self, item):
|
||||
# item: Tuple[Tensor, int]
|
||||
waveform, origin_sr = item[0], item[1]
|
||||
waveform = self.waveform_resample(waveform, origin_sr)
|
||||
waveform_clips = self.clip_sample(waveform)
|
||||
melspec_clips = self.waveform_melspec(waveform_clips)
|
||||
normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
|
||||
all_clips = torch.stack(normed_melspecs, dim=0)
|
||||
# all_clips: [clips_per_video, channel, mel_bins, time_steps]
|
||||
# augment: [batch_size, channel, time_steps, freq_bins]
|
||||
augmented_clips = self.spec_augment(all_clips.transpose(-2, -1)).transpose(-2, -1)
|
||||
return augmented_clips
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
if cfg is None:
|
||||
cfg = OmegaConf.create()
|
||||
|
||||
target_sr = cfg.get("target_sr", 16000)
|
||||
clip_duration = cfg.get("clip_duration", 2)
|
||||
clips_per_video = cfg.get("clips_per_video", 3)
|
||||
num_mel_bins = cfg.get("num_mel_bins", 128)
|
||||
target_length = cfg.get("target_length", 204)
|
||||
time_drop_width = cfg.get("time_drop_width", 13)
|
||||
time_stripes_num = cfg.get("time_stripes_num", 2)
|
||||
# 13 * 2 / 204 = 12.75% Time Mask
|
||||
freq_drop_width = cfg.get("freq_drop_width", 8)
|
||||
freq_stripes_num = cfg.get("freq_stripes_num", 2)
|
||||
# 8 * 2 / 128 = 12.5% Freq Mask
|
||||
mask_type = cfg.get("mask_type", 'mixture')
|
||||
|
||||
mean = cfg.get("mean", None)
|
||||
std = cfg.get("std", None)
|
||||
|
||||
return cls(
|
||||
mean=mean, std=std, target_sr=target_sr,
|
||||
clip_duration=clip_duration, clips_per_video=clips_per_video,
|
||||
num_mel_bins=num_mel_bins, target_length=target_length,
|
||||
time_drop_width=time_drop_width, time_stripes_num=time_stripes_num,
|
||||
freq_drop_width=freq_drop_width, freq_stripes_num=freq_stripes_num,
|
||||
mask_type=mask_type
|
||||
)
|
||||
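Reviewer sketch: the inline comments in `from_config` quote mask fractions of 12.75% (time) and 12.5% (frequency). These are nominal figures computed as `width * stripes_num / total`. Since the stripe width is drawn with an exclusive upper bound, and stripes may overlap, the actually masked fraction is somewhat lower. The arithmetic can be checked as follows; the function names are hypothetical.

```python
def nominal_mask_fraction(width: int, stripes_num: int, total: int) -> float:
    """Fraction quoted in the comments: every stripe at the full configured width."""
    return width * stripes_num / total


def max_mask_fraction(width: int, stripes_num: int, total: int) -> float:
    """Upper bound when widths are sampled from [0, width) and stripes do not overlap."""
    return (width - 1) * stripes_num / total


time_nominal = nominal_mask_fraction(13, 2, 204)  # time_drop_width, time_stripes_num, target_length
freq_nominal = nominal_mask_fraction(8, 2, 128)   # freq_drop_width, freq_stripes_num, num_mel_bins
time_max = max_mask_fraction(13, 2, 204)
freq_max = max_mask_fraction(8, 2, 128)
```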
|
||||
|
||||
@registry.register_processor("imagebind_audio_eval")
|
||||
class ImageBindAudioEvalProcessor(ImageBindAudioBaseProcessor):
|
||||
def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
|
||||
clip_sample_method="Constant", num_mel_bins=None, target_length=None):
|
||||
super().__init__(mean=mean, std=std, target_sr=target_sr,
|
||||
clip_duration=clip_duration, clips_per_video=clips_per_video,
|
||||
num_mel_bins=num_mel_bins, target_length=target_length,
|
||||
clip_sample_method=clip_sample_method)
|
||||
|
||||
def __call__(self, item):
|
||||
# item: Tuple[Tensor, int]
|
||||
waveform, origin_sr = item[0], item[1]
|
||||
waveform = self.waveform_resample(waveform, origin_sr)
|
||||
waveform_clips = self.clip_sample(waveform)
|
||||
melspec_clips = self.waveform_melspec(waveform_clips)
|
||||
normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
|
||||
all_clips = torch.stack(normed_melspecs, dim=0)
|
||||
# all_clips: [clips_per_video, channel, mel_bins, time_steps]
|
||||
return all_clips
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
if cfg is None:
|
||||
cfg = OmegaConf.create()
|
||||
|
||||
target_sr = cfg.get("target_sr", 16000)
|
||||
clip_duration = cfg.get("clip_duration", 2)
|
||||
clips_per_video = cfg.get("clips_per_video", 3)
|
||||
num_mel_bins = cfg.get("num_mel_bins", 128)
|
||||
target_length = cfg.get("target_length", 204)
|
||||
|
||||
mean = cfg.get("mean", None)
|
||||
std = cfg.get("std", None)
|
||||
|
||||
return cls(
|
||||
mean=mean, std=std, target_sr=target_sr,
|
||||
clip_duration=clip_duration, clips_per_video=clips_per_video,
|
||||
num_mel_bins=num_mel_bins, target_length=target_length
|
||||
)
|
||||
|
@ -9,7 +9,7 @@ import re
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.processors.base_processor import BaseProcessor
|
||||
from minigpt4.processors.randaugment import RandomAugment
|
||||
from minigpt4.processors.vision_augment import RandomAugment
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
@ -147,3 +147,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
|
||||
|
||||
return cls(image_size=image_size, mean=mean, std=std)
|
||||
|
||||
|
||||
|
@ -53,7 +53,6 @@ ipdb
|
||||
tensorflow-cpu
|
||||
tensorboardX
|
||||
# mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
||||
|
||||
bs4==0.0.1 # Needed for text cleaning
|
||||
bson==0.5.10
|
||||
byted-dataloader==0.3.6
|
||||
@ -66,3 +65,4 @@ tensorboard==2.11.2
|
||||
tensorflow==2.11.0 # Needed for tensorboard hdfs support
|
||||
tensorflow-io==0.30.0 # Needed for tensorboard hdfs support
|
||||
tqdm==4.64.1
|
||||
pytorchvideo==0.1.5
|
68
train_configs/bindgpt4_stage1_audio_pretrain.yaml
Normal file
@ -0,0 +1,68 @@
|
||||
model:
|
||||
arch: bind_gpt4
|
||||
model_type: pretrain_vicuna
|
||||
freeze_imagebind: True
|
||||
freeze_qformer: False
|
||||
|
||||
datasets:
|
||||
bbc:
|
||||
audio_processor:
|
||||
train:
|
||||
name: "imagebind_audio_train"
|
||||
text_processor:
|
||||
train:
|
||||
name: "imagebind_caption"
|
||||
sample_ratio: 31
|
||||
audioset:
|
||||
audio_processor:
|
||||
train:
|
||||
name: "imagebind_audio_train"
|
||||
text_processor:
|
||||
train:
|
||||
name: "imagebind_caption"
|
||||
sample_ratio: 108
|
||||
soundbible:
|
||||
audio_processor:
|
||||
train:
|
||||
name: "imagebind_audio_train"
|
||||
text_processor:
|
||||
train:
|
||||
name: "imagebind_caption"
|
||||
sample_ratio: 2
|
||||
freesound:
|
||||
audio_processor:
|
||||
train:
|
||||
name: "imagebind_audio_train"
|
||||
text_processor:
|
||||
train:
|
||||
name: "imagebind_caption"
|
||||
sample_ratio: 262
|
||||
|
||||
run:
|
||||
task: image_text_pretrain
|
||||
# optimizer
|
||||
lr_sched: "linear_warmup_cosine_lr"
|
||||
init_lr: 1e-4
|
||||
min_lr: 8e-5
|
||||
warmup_lr: 1e-6
|
||||
|
||||
weight_decay: 0.05
|
||||
max_epoch: 4
|
||||
batch_size_train: 64
|
||||
batch_size_eval: 64
|
||||
num_workers: 4
|
||||
warmup_steps: 5000
|
||||
iters_per_epoch: 5000
|
||||
|
||||
seed: 42
|
||||
output_dir: "output/bindgpt4_stage1_audio_pretrain"
|
||||
|
||||
amp: True
|
||||
resume_ckpt_path: null
|
||||
evaluate: False
|
||||
train_splits: ["train"]
|
||||
|
||||
device: "cuda"
|
||||
world_size: 1
|
||||
dist_url: "env://"
|
||||
distributed: True
|
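Reviewer sketch: the `run` section selects `linear_warmup_cosine_lr` with `warmup_lr: 1e-6`, `init_lr: 1e-4`, `min_lr: 8e-5` and `warmup_steps: 5000`. The functions below show a generic linear warmup followed by cosine decay using those values. How the repository's scheduler actually combines the two phases (per step or per epoch) is not visible in this diff, so the function names and the `progress` argument are assumptions.

```python
import math


def warmup_lr(step: int, warmup_steps: int, start_lr: float, peak_lr: float) -> float:
    """Linearly increase from start_lr to peak_lr over warmup_steps."""
    frac = min(step / max(warmup_steps, 1), 1.0)
    return start_lr + (peak_lr - start_lr) * frac


def cosine_lr(progress: float, peak_lr: float, min_lr: float) -> float:
    """Cosine decay from peak_lr (progress=0) to min_lr (progress=1)."""
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Values copied from the config above.
INIT_LR, MIN_LR, WARMUP_LR, WARMUP_STEPS = 1e-4, 8e-5, 1e-6, 5000

lr_start = warmup_lr(0, WARMUP_STEPS, WARMUP_LR, INIT_LR)
lr_peak = warmup_lr(WARMUP_STEPS, WARMUP_STEPS, WARMUP_LR, INIT_LR)
lr_mid = cosine_lr(0.5, INIT_LR, MIN_LR)
lr_end = cosine_lr(1.0, INIT_LR, MIN_LR)
```

With `warmup_steps` equal to `iters_per_epoch`, the warmup phase spans the first of the four configured epochs, and the decay only narrows the rate from 1e-4 to 8e-5.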