Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-06 19:10:45 +00:00)

Commit 05220fe3c1 (parent 3efda2ac76): init audio data config (#2)

- Add audio datasets
- Add audio processors
- Add audio support in bindgpt
- Add audio training config

Co-authored-by: bingyikang <bingyikang@bytedance.com>
Co-authored-by: zhaoyang <913556700@qq.com>

README.md (175 lines changed)
@@ -1,170 +1,5 @@
-# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
+# TODO List
-[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
+- [ ] reconstruct the eval scripts
+- [ ] support audio evaluation
-**King Abdullah University of Science and Technology**
+- [ ] clean code repo & solve recursive import
+- [x] audio processor: random sampling
-<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
-## News
-We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB.
-## Online Demo
-Click the image to chat with MiniGPT-4 around your images
-[](https://minigpt-4.github.io)
-## Examples
-| | |
-:-------------------------:|:-------------------------:
- | 
- | 
-More examples can be found in the [project page](https://minigpt-4.github.io).
-## Introduction
-- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer.
-- We train MiniGPT-4 with two stages. The first traditional pretraining stage is trained using roughly 5 million aligned image-text pairs in 10 hours using 4 A100s. After the first stage, Vicuna is able to understand the image. But the generation ability of Vicuna is heavilly impacted.
-- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset.
-- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100.
-- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4.
-## Getting Started
-### Installation
-**1. Prepare the code and the environment**
-Git clone our repository, creating a python environment and ativate it via the following command
-```bash
-git clone https://github.com/Vision-CAIR/MiniGPT-4.git
-cd MiniGPT-4
-conda env create -f environment.yml
-conda activate minigpt4
-```
-**2. Prepare the pretrained Vicuna weights**
-The current version of MiniGPT-4 is built on the v0 versoin of Vicuna-13B.
-Please refer to our instruction [here](PrepareVicuna.md)
-to prepare the Vicuna weights.
-The final weights would be in a single folder in a structure similar to the following:
-```
-vicuna_weights
-├── config.json
-├── generation_config.json
-├── pytorch_model.bin.index.json
-├── pytorch_model-00001-of-00003.bin
-...
-```
-Then, set the path to the vicuna weight in the model config file
-[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
-**3. Prepare the pretrained MiniGPT-4 checkpoint**
-Download the pretrained checkpoints according to the Vicuna model you prepare.
-| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
-:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
-[Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing)
-Then, set the path to the pretrained checkpoint in the evaluation config file
-in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
-### Launching Demo Locally
-Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running
-```
-python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
-```
-To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1.
-This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B.
-For more powerful GPUs, you can run the model
-in 16 bit by setting low_resource to False in the config file
-[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
-Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
-### Training
-The training of MiniGPT-4 contains two alignment stages.
-**1. First pretraining stage**
-In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets
-to align the vision and language model. To download and prepare the datasets, please check
-our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
-After the first stage, the visual features are mapped and can be understood by the language
-model.
-To launch the first stage training, run the following command. In our experiments, we use 4 A100.
-You can change the save path in the config file
-[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
-```bash
-torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
-```
-A MiniGPT-4 checkpoint with only stage one training can be downloaded
-[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
-Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently.
-**2. Second finetuning stage**
-In the second stage, we use a small high quality image-text pair dataset created by ourselves
-and convert it to a conversation format to further align MiniGPT-4.
-To download and prepare our second stage dataset, please check our
-[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
-To launch the second stage alignment,
-first specify the path to the checkpoint file trained in stage 1 in
-[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml).
-You can also specify the output path there.
-Then, run the following command. In our experiments, we use 1 A100.
-```bash
-torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
-```
-After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly.
-## Acknowledgement
-+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check this great open-source work if you don't know it before!
-+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
-+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
-If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX:
-```bibtex
-@article{zhu2023minigpt,
-  title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
-  author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
-  journal={arXiv preprint arXiv:2304.10592},
-  year={2023}
-}
-```
-## License
-This repository is under [BSD 3-Clause License](LICENSE.md).
-Many codes are based on [Lavis](https://github.com/salesforce/LAVIS) with
-BSD 3-Clause License [here](LICENSE_Lavis.md).
@@ -6,7 +6,9 @@
pip3 install --upgrade pip
pip3 install -r requirements.txt
pip3 install byted-dataloader -i "https://bytedpypi.byted.org/simple"
-mmengine-0.7.3
+pip3 install mmengine==0.7.3
+pip3 install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html

# unset http_proxy && unset https_proxy && unset no_proxy

# # ----------------------------------------------------------------------------------------
@@ -1,4 +1,5 @@
import torch
+import torchaudio
from PIL import Image

@@ -13,3 +14,12 @@ def load_image(image, image_processor):
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    return image


+def load_audio(audio, audio_processor):
+    if isinstance(audio, str):  # is a audio path
+        raw_audio = torchaudio.load(audio)
+        audio = audio_processor(audio)
+    # elif isinstance(audio, )
+    else:
+        raise NotImplementedError
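For reference when reading the helper above: `torchaudio.load` returns a `(waveform, sample_rate)` pair. The snippet below is only a hedged illustration of that return contract; the file path is a placeholder and is not shipped with this commit.

```python
import torchaudio

# Illustrative only: "example.flac" is a placeholder path.
waveform, sample_rate = torchaudio.load("example.flac")
print(waveform.shape, sample_rate)  # e.g. torch.Size([1, 160000]) 16000
```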
@@ -125,6 +125,7 @@ with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil")
+            # audio = gr.Audio()
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")

@@ -150,7 +151,7 @@ with gr.Blocks() as demo:
            chat_state = gr.State()
            emb_list = gr.State()
            chatbot = gr.Chatbot(label='BindGPT-4')
-            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
+            text_input = gr.Textbox(label='User', placeholder='Please upload your image/audio first', interactive=False)

    upload_button.click(upload_img, [image, text_input, chat_state],
                        [image, text_input, upload_button, chat_state, emb_list])
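The commented-out `gr.Audio()` above hints at the planned audio input. A hedged sketch of what that component could look like follows; wiring it into `upload_img` and the chat state is not part of this diff.

```python
import gradio as gr

# Hypothetical audio upload widget; not connected to the chat pipeline in this commit.
audio = gr.Audio(type="filepath", label="Audio")
```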
@@ -15,7 +15,7 @@ import logging
from imagebind.models.multimodal_preprocessors import SimpleTokenizer
from PIL import Image
from pytorchvideo import transforms as pv_transforms
-from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
from pytorchvideo.data.encoded_video import EncodedVideo

from torchvision import transforms

@@ -59,23 +59,11 @@ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        fbank = fbank[:, 0:target_length]
-    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
-    # channel image
+    # Convert to [1, mel_bins, num_frames] shape, essentially like a 1 channel image
    fbank = fbank.unsqueeze(0)
    return fbank


-def get_clip_timepoints(clip_sampler, duration):
-    # Read out all clips in this video
-    all_clips_timepoints = []
-    is_last_clip = False
-    end = 0.0
-    while not is_last_clip:
-        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
-        all_clips_timepoints.append((start, end))
-    return all_clips_timepoints
-
-
def load_and_transform_vision_data(image_paths, device):
    if image_paths is None:
        return None

@@ -137,14 +125,14 @@ def load_and_transform_audio_data(
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=sample_rate
            )
-        all_clips_timepoints = get_clip_timepoints(
+        all_clips_timepoints = get_constant_clip_timepoints(
            clip_sampler, waveform.size(1) / sample_rate
        )
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            waveform_clip = waveform[
                :,
-                int(clip_timepoints[0] * sample_rate) : int(
+                int(clip_timepoints[0] * sample_rate): int(
                    clip_timepoints[1] * sample_rate
                ),
            ]

@@ -162,7 +150,8 @@ def load_and_transform_audio_data(
    return torch.stack(audio_outputs, dim=0)


-def get_clip_timepoints(clip_sampler, duration):
+def get_constant_clip_timepoints(clip_sampler, duration):
+    assert isinstance(clip_sampler, ConstantClipsPerVideoSampler), "Incompatible Type of Sampler!"
    # Read out all clips in this video
    all_clips_timepoints = []
    is_last_clip = False
@@ -173,6 +162,13 @@ def get_clip_timepoints(clip_sampler, duration):
    return all_clips_timepoints


+def get_random_clip_timepoints(clip_sampler, duration):
+    assert isinstance(clip_sampler, RandomMultiClipSampler), "Incompatible Type of Sampler!"
+    starts, ends, _, _, _ = clip_sampler(0.0, duration, annotation=None)
+    all_clips_timepoints = sorted(list(zip(starts, ends)), key=lambda x: x[0])
+    return all_clips_timepoints
+
+
def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
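Both helpers consume a pytorchvideo clip sampler. The following self-contained sketch (durations chosen arbitrarily for illustration) shows the loop pattern that `get_constant_clip_timepoints` uses above:

```python
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler

# Sample 3 fixed 2-second clips from a 10-second waveform; values are illustrative.
sampler = ConstantClipsPerVideoSampler(clip_duration=2.0, clips_per_video=3)
timepoints, is_last, end = [], False, 0.0
while not is_last:
    start, end, _, _, is_last = sampler(end, 10.0, annotation=None)
    timepoints.append((float(start), float(end)))
print(timepoints)  # three evenly spaced (start, end) pairs covering the 10 s
```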
@@ -244,7 +240,7 @@ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
        x_offset = 0
    elif spatial_idx == 2:
        x_offset = width - size
-    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+    cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)

@@ -328,7 +324,7 @@ def load_and_transform_video_data(
            **{"sample_rate": sample_rate},
        )

-        all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
+        all_clips_timepoints = get_constant_clip_timepoints(clip_sampler, video.duration)

        all_video = []
        for clip_timepoints in all_clips_timepoints:
@@ -50,46 +50,59 @@ ModalityType = SimpleNamespace(
class ImageBindJoiner(nn.Module):
    def __init__(self,
                 vision_query_token_num: int,
+                 audio_query_token_num: int,
                 vision_qformer_frozen: bool = False,
                 vision_qformer_model: str = "",  # The url or path of pre-trained vision Q-Former model
                 vision_pre_dims: List[int] = (),  # Projection before Q-Former
-                 vision_post_dims: List[int] = (768, 768)  # Projection after Q-Former
+                 vision_post_dims: List[int] = (1280, 768),  # Projection after Q-Former
+                 audio_pre_dims: List[int] = (),
+                 audio_post_dims: List[int] = (768, 768)
                 ):
        super().__init__()
        assert not (vision_qformer_frozen and vision_qformer_model == "")
-        self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims)
+        self.modality_pre_projectors = self._create_modality_pre_projectors(vision_pre_dims, audio_pre_dims)
        self.modality_qformers = self._create_modality_qformers(vision_query_token_num,
                                                                vision_qformer_frozen,
-                                                                vision_qformer_model)
-        self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims)
+                                                                vision_qformer_model,
+                                                                audio_query_token_num)
+        self.modality_post_projectors = self._create_modality_post_projectors(vision_post_dims, audio_post_dims)

    def _create_modality_pre_projectors(self,
-                                        vision_pre_dims
+                                        vision_pre_dims,
+                                        audio_pre_dims
                                        ):
        modality_pre_projectors = {
-            ModalityType.VISION: create_projectors(vision_pre_dims)
+            ModalityType.VISION: create_projectors(vision_pre_dims),
+            ModalityType.AUDIO: create_projectors(audio_pre_dims)
        }
        return modality_pre_projectors

    def _create_modality_qformers(self,
                                  vision_query_token_num,
                                  vision_qformer_frozen,
-                                  vision_qformer_model
+                                  vision_qformer_model,
+                                  audio_query_token_num
                                  ):
        vision_qformer = SequenceGenericQFormer(num_query_token=vision_query_token_num,
                                                freeze_qformer=vision_qformer_frozen,
                                                encoder_width=1280,  # TODO: fix hard-coding
                                                q_former_model=vision_qformer_model)
+        audio_qformer = SequenceGenericQFormer(num_query_token=audio_query_token_num,
+                                               freeze_qformer=False,
+                                               encoder_width=768)
        modality_qformers = {
-            ModalityType.VISION: vision_qformer
+            ModalityType.VISION: vision_qformer,
+            ModalityType.AUDIO: audio_qformer
        }

        return nn.ModuleDict(modality_qformers)

-    def _create_modality_post_projectors(self, vision_post_dims):
+    def _create_modality_post_projectors(self, vision_post_dims, audio_post_dims):
        vision_projector = create_projectors(vision_post_dims)
+        audio_projector = create_projectors(audio_post_dims)
        modality_projectors = {
-            ModalityType.VISION: vision_projector
+            ModalityType.VISION: vision_projector,
+            ModalityType.AUDIO: audio_projector
        }

        return nn.ModuleDict(modality_projectors)

@@ -97,7 +110,6 @@ class ImageBindJoiner(nn.Module):
    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        outputs = {}
        for modality_key, modality_value in inputs.items():
-            # assert modality_key == ModalityType.VISION, "Only Vision is Currently Supported."
            if modality_value is not None:
                modality_value = self.modality_pre_projectors[modality_key](modality_value)
                modality_value = self.modality_qformers[modality_key](modality_value)

@@ -544,7 +556,7 @@ class ImageBindModel(nn.Module):

            # NOTE: The reduction operation has been modified.
            if reduce_list:
-                modality_value = modality_value.reshape(B, S, *modality_value[2:])
+                modality_value = modality_value.reshape(B, S, *modality_value.shape[1:])
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value
@@ -32,10 +32,12 @@ class Registry:
        """

        def wrap(builder_cls):
+            # TODO: merge them or split builders by modality
            from minigpt4.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
+            from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder

            assert issubclass(
-                builder_cls, ImageBaseDatasetBuilder
+                builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder)
            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
                builder_cls
            )
minigpt4/configs/datasets/audioset/defaults.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
datasets:
  audioset:
    data_type: audio
    build_info:
      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/AudioSet_SL/AudioSet_SL{00..54}.tar

minigpt4/configs/datasets/bbc/defaults.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
datasets:
  bbc:
    data_type: audio
    build_info:
      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/BBC_Sound_Effects/BBC_Sound_Effects{000000..000062}.tar

minigpt4/configs/datasets/freesound/defaults.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
datasets:
  freesound:
    data_type: audio
    build_info:
      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/FreeSound/FreeSound{000000..000524}.tar

minigpt4/configs/datasets/soundbible/defaults.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
datasets:
  soundbible:
    data_type: audio
    build_info:
      storage: /mnt/bn/zilongdata-hl/dataset/wavcaps/web_datasets/SoundBible/SoundBible0.tar
@@ -11,12 +11,23 @@ from minigpt4.datasets.builders.image_text_pair_builder import (
    LaionBuilderImage,
    CCSBUAlignBuilderImage
)
+from minigpt4.datasets.builders.audio_text_pair_builder import (
+    BBCBuilder,
+    AudioSetBuilder,
+    SoundBibleBuilder,
+    FreeSoundBuilder
+)
from minigpt4.common.registry import registry

__all__ = [
    "CCSBUBuilderImage",
    "LaionBuilderImage",
-    "CCSBUAlignBuilderImage"
+    "CCSBUAlignBuilderImage",
+    # Audio builders
+    "BBCBuilder",
+    "AudioSetBuilder",
+    "SoundBibleBuilder",
+    "FreeSoundBuilder",
]
@@ -31,6 +31,51 @@ class AudioBaseDatasetBuilder:

        self.data_type = self.config.data_type

+        self.audio_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+
+    def build_datasets(self):
+        # download, split, etc...
+        # only called on 1 GPU/TPU in distributed
+
+        if is_main_process():
+            self._download_data()
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+        logging.info("Building datasets...")
+        datasets = self.build()  # dataset['train'/'val'/'test']
+
+        return datasets
+
+    def build_processors(self):
+        aud_proc_cfg = self.config.get("audio_processor")
+        txt_proc_cfg = self.config.get("text_processor")
+
+        if aud_proc_cfg is not None:
+            aud_train_cfg = aud_proc_cfg.get("train")
+            aud_eval_cfg = aud_proc_cfg.get("eval")
+
+            self.audio_processors["train"] = self._build_proc_from_cfg(aud_train_cfg)
+            self.audio_processors["eval"] = self._build_proc_from_cfg(aud_eval_cfg)
+
+        if txt_proc_cfg is not None:
+            txt_train_cfg = txt_proc_cfg.get("train")
+            txt_eval_cfg = txt_proc_cfg.get("eval")
+
+            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
+            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
+
+    @staticmethod
+    def _build_proc_from_cfg(cfg):
+        return (
+            registry.get_processor_class(cfg.name).from_config(cfg)
+            if cfg is not None
+            else None
+        )
+
    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

@@ -95,5 +140,3 @@ class AudioBaseDatasetBuilder:
        filename = os.path.basename(storage_path)

        download_url(url=url_or_filename, root=dirname, filename=filename)
minigpt4/datasets/builders/audio_text_pair_builder.py (new file, 56 lines)
@@ -0,0 +1,56 @@
import os
import logging
import warnings

from minigpt4.common.registry import registry
from minigpt4.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
from minigpt4.datasets.datasets.audio_caption import GenericAudioDataset


class GenericAudioBuilder(AudioBaseDatasetBuilder):
    train_dataset_cls = GenericAudioDataset

    def _download_ann(self):
        pass

    def _download_aud(self):
        pass

    def build(self):
        self.build_processors()

        build_info = self.config.build_info

        datasets = dict()
        split = "train"

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets[split] = dataset_cls(
            audio_processor=self.audio_processors[split],
            text_processor=self.text_processors[split],
            location=build_info.storage,
        ).inner_dataset

        return datasets


@registry.register_builder("bbc")
class BBCBuilder(GenericAudioBuilder):
    DATASET_CONFIG_DICT = {"default": "configs/datasets/bbc/defaults.yaml"}


@registry.register_builder("audioset")
class AudioSetBuilder(GenericAudioBuilder):
    DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults.yaml"}


@registry.register_builder("soundbible")
class SoundBibleBuilder(GenericAudioBuilder):
    DATASET_CONFIG_DICT = {"default": "configs/datasets/soundbible/defaults.yaml"}


@registry.register_builder("freesound")
class FreeSoundBuilder(GenericAudioBuilder):
    DATASET_CONFIG_DICT = {"default": "configs/datasets/freesound/defaults.yaml"}
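For orientation, a minimal, hedged sketch of how these registered builders could be exercised, assuming the MiniGPT-4 package is importable and the shard paths in the default YAMLs are reachable; this is not an entry point added by the commit.

```python
# Hedged usage sketch; registry.get_builder_class is the standard LAVIS-style lookup.
from minigpt4.common.registry import registry
import minigpt4.datasets.builders  # noqa: F401  (importing runs the @registry.register_builder decorators)

builder_cls = registry.get_builder_class("audioset")   # -> AudioSetBuilder
builder = builder_cls()                                 # assumed to load configs/datasets/audioset/defaults.yaml
datasets = builder.build_datasets()                     # {"train": webdataset pipeline over the .tar shards}
```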
@@ -0,0 +1 @@
from minigpt4.datasets.datasets.audio_caption.audio_caption_datasets import GenericAudioDataset

@@ -6,21 +6,22 @@ from minigpt4.datasets.datasets.base_dataset import BaseDataset


class GenericAudioDataset(BaseDataset):
-    def __init__(self, vision_processor, text_processor, location):
-        super().__init__(x_processor=vision_processor, text_processor=text_processor)
+    def __init__(self, audio_processor, text_processor, location):
+        super().__init__(x_processor=audio_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
-            wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
+            wds.to_tuple("flac", "json", handler=wds.warn_and_continue),
            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
-            "image": sample[0],
+            "audio": sample[0],
+            # [clips_per_video, channel, mel_bins, time_steps]
            "text_input": self.text_processor(sample[1]["caption"]),
        }
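A rough, hedged sketch of what this pipeline yields per sample, assuming the shard pattern from the configs above is reachable and with identity functions standing in for the real audio/text processors:

```python
# Placeholders only: the path pattern and identity processors are illustrative.
dataset = GenericAudioDataset(audio_processor=lambda a: a,
                              text_processor=lambda t: t,
                              location="path/to/AudioSet_SL{00..54}.tar")
sample = next(iter(dataset.inner_dataset))
print(sample.keys())  # dict_keys(['audio', 'text_input'])
```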
@@ -21,7 +21,7 @@ class CCSBUDataset(BaseDataset):

    def to_dict(self, sample):
        return {
-            "image": sample[0],
+            "vision": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }

@@ -41,7 +41,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):
        caption = ann["caption"]

        return {
-            "image": image,
+            "vision": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }

@@ -49,7 +49,7 @@ class CCSBUAlignDatasetImageImageCaptionDataset(ImageCaptionDataset):

class CCDataset(BaseDataset):
    def __init__(self, vis_processor, text_processor, location):
-        super().__init__(vis_processor=vis_processor, text_processor=text_processor)
+        super().__init__(x_processor=vis_processor, text_processor=text_processor)

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),

@@ -57,12 +57,12 @@ class CCDataset(BaseDataset):
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode("pilrgb", handler=wds.warn_and_continue),
            wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),
-            wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
+            wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
            wds.map(self.to_dict, handler=wds.warn_and_continue),
        )

    def to_dict(self, sample):
        return {
-            "image": sample[0],
+            "vision": sample[0],
            "text_input": sample[1],
        }

@@ -42,7 +42,7 @@ class ImageCaptionDataset(BaseDataset, __ImageDisplMixin):
        caption = self.text_processor(ann["caption"])

        return {
-            "image": image,
+            "vision": image,
            "text_input": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }

@@ -67,7 +67,7 @@ class CaptionEvalDataset(BaseDataset, __ImageDisplMixin):
        image = self.x_processor(image)

        return {
-            "image": image,
+            "vision": image,
            "image_id": ann["image_id"],
            "instance_id": ann["instance_id"],
        }

@@ -25,7 +25,7 @@ class LaionDataset(BaseDataset):

    def to_dict(self, sample):
        return {
-            "image": sample[0],
+            "vision": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
        }

@@ -9,7 +9,7 @@ class __ImageDisplMixin:
            {
                "file": ann["image"],
                "caption": ann["caption"],
-                "image": sample["image"],
+                "vision": sample["vision"],
            }
        )
@@ -1,8 +1,9 @@
import random
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List

import torch
import torch.nn as nn
+import re
from torch import Tensor
from transformers import LlamaTokenizer

@@ -12,6 +13,36 @@ from minigpt4.models.blip2 import BaseModel
from minigpt4.models.modeling_llama import LlamaForCausalLM


+def filter_prompt(input_embeds: Dict[str, Tensor], prompt_list: List[str]) -> List[str]:
+    if not prompt_list:
+        return prompt_list
+    input_modal_set = set([k.title() for k in input_embeds if input_embeds[k] is not None])
+    prompt_modal_sets = [set(re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)) for prompt in prompt_list]
+    results = [prompt_list[i] for i, prompt_modal_set in enumerate(prompt_modal_sets) if
+               prompt_modal_set == input_modal_set]
+    return results
+
+
+def arrange_modalities(input_embeds: Dict[str, Tensor], prompt: str) -> List[Tensor]:
+    prompt_modalities = re.findall("<([^<>]+)><ModalityHere></\\1>", prompt)
+    return [input_embeds[modality.lower()] for modality in prompt_modalities]
+
+
+def concat_all_embeddings(input_embeds: Dict[str, Tensor], dim: int) -> Tensor:
+    embeds = [input_embeds[key] for key in input_embeds if input_embeds[key] is not None]
+    return torch.cat(embeds, dim=dim)
+
+
+def filter_modalities(inputs):
+    filtered_inputs = {}
+
+    for k in ModalityType.__dict__.values():
+        if k in inputs:
+            filtered_inputs[k] = inputs[k]
+
+    return filtered_inputs
+
+
@registry.register_model("bind_gpt4")
class BindGPT4(BaseModel):
    """
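To make the tag-matching behaviour of these helpers concrete, here is a small, self-contained check using the `filter_prompt` function defined above. The prompt templates are illustrative; the real prompt list comes from the model config, not from this diff.

```python
import torch

embeds = {"vision": torch.zeros(1, 32, 8), "audio": None}  # only vision present
prompts = [
    "<Vision><ModalityHere></Vision> Describe the image.",
    "<Audio><ModalityHere></Audio> Describe the sound.",
    "<Vision><ModalityHere></Vision> <Audio><ModalityHere></Audio> Describe both.",
]
# Only templates whose modality tags exactly match the non-None inputs survive:
print(filter_prompt(embeds, prompts))
# -> ['<Vision><ModalityHere></Vision> Describe the image.']
```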
@@ -61,7 +92,9 @@ class BindGPT4(BaseModel):
        print('Loading Q-Former and Adapter/Projector')
        self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
                                                 vision_qformer_frozen=freeze_qformer,
-                                                 vision_post_dims=[768, self.llama_model.config.hidden_size]
+                                                 vision_post_dims=[768, self.llama_model.config.hidden_size],
+                                                 audio_query_token_num=num_query_token,
+                                                 audio_post_dims=[768, self.llama_model.config.hidden_size]
                                                 # vision_qformer_model=q_former_model,
                                                 # vision_pre_dims=(1280, 1408)
                                                 )

@@ -84,42 +117,37 @@ class BindGPT4(BaseModel):
    def encode_inputs(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        imagebind_outputs = self.multimodal_encoder(inputs)
        llama_inputs = self.multimodal_joiner(imagebind_outputs)
-        # NOTE: only accept image here
        return llama_inputs

-    def prompt_wrap(self, inputs: Dict[str, Tensor], modality_name: str, prompt: str) -> Tuple[Tensor, Tensor]:
-        # TODO: Accept More Modalities.
-        input_embeds = inputs[modality_name]
-        attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
-        if prompt:
-            batch_size = input_embeds.shape[0]
-            p_before, p_after = prompt.split('<ModalityHere>')
-            p_before_tokens = self.llama_tokenizer(
-                p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
-            p_after_tokens = self.llama_tokenizer(
-                p_after, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
-            p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
-            p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
-            wrapped_input_embeds = torch.cat([p_before_embeds, inputs, p_after_embeds], dim=1)
-            wrapped_atts_input = attns_input[:, :1].expand(-1, wrapped_input_embeds.shape[1])
-            return wrapped_input_embeds, wrapped_atts_input
-        else:
-            return input_embeds, attns_input
+    def prompt_wrap(self, inputs: Dict[str, Tensor], prompt: str) -> Tuple[Tensor, Tensor]:
+        if not prompt:
+            input_embeds = concat_all_embeddings(inputs, dim=1)
+            attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
+            return input_embeds, attns_input
+        input_embeds_list = arrange_modalities(inputs, prompt)
+        batch_size = input_embeds_list[0].shape[0]
+        prompt_slices = prompt.split('<ModalityHere>')
+        prompt_tokens = [self.llama_tokenizer(prompt_slice, return_tensors="pt", add_special_tokens=False)
+                         .to(input_embeds_list[0].device) for prompt_slice in prompt_slices]
+        prompt_embeds = [self.llama_model.model.embed_tokens(prompt_token.input_ids).expand(batch_size, -1, -1)
+                         for prompt_token in prompt_tokens]
+        result_embeds = [emb for pair in zip(prompt_embeds[:-1], input_embeds_list)
+                         for emb in pair] + [prompt_embeds[-1]]
+        wrapped_input_embeds = torch.cat(result_embeds, dim=1)
+        wrapped_atts_input = torch.ones(wrapped_input_embeds.size()[:-1],
+                                        dtype=torch.long).to(wrapped_input_embeds.device)
+        return wrapped_input_embeds, wrapped_atts_input

    def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        """
-        TODO: More Modalities.
-        Only accept image inputs here.
-        Other modalities will conflict with the pre-defined prompt and wrapping strategy.
-        """
-        bind_inputs = {ModalityType.VISION: inputs['image']}
-        embeds = self.encode_inputs(bind_inputs)
-        # assert "vision" in embeds, "Only Vision Input Can Be Accepted Now."
-        if self.prompt_list:
-            prompt = random.choice(self.prompt_list)
+        # filter `inputs` as it may contain information other than modalities
+        modality_inputs = filter_modalities(inputs)
+        embeds = self.encode_inputs(modality_inputs)
+        filtered_prompts = filter_prompt(embeds, self.prompt_list)
+        if filtered_prompts:
+            prompt = random.choice(filtered_prompts)
        else:
            prompt = None
-        img_embeds, atts_img = self.prompt_wrap(embeds, ModalityType.VISION, prompt)
+        input_embs, input_atts = self.prompt_wrap(embeds, prompt)

        # NOTE: No modifications from the next line to the end. Except for the autocast part.
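The interleaving performed by the new `prompt_wrap` can be sanity-checked in isolation. This toy snippet, with invented shapes, mirrors the zip-and-flatten expression above: prompt slices p0, p1, p2 around two modality embeddings e0, e1 are concatenated as [p0, e0, p1, e1, p2].

```python
import torch

p = [torch.zeros(1, 2, 4), torch.zeros(1, 1, 4), torch.zeros(1, 3, 4)]  # prompt slices
e = [torch.ones(1, 32, 4), torch.ones(1, 8, 4)]                          # vision, audio embeddings
interleaved = [t for pair in zip(p[:-1], e) for t in pair] + [p[-1]]
print(torch.cat(interleaved, dim=1).shape)  # torch.Size([1, 46, 4]) = 2 + 32 + 1 + 8 + 3
```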
@@ -134,28 +162,28 @@
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
-        ).to(img_embeds.device)
+        ).to(input_embs.device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )

        empty_targets = (
-            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
-                       dtype=torch.long).to(img_embeds.device).fill_(-100)  # plus one for bos
+            torch.ones([input_atts.shape[0], input_atts.shape[1] + 1],
+                       dtype=torch.long).to(input_embs.device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

-        batch_size = img_embeds.shape[0]
+        batch_size = input_embs.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
-        atts_bos = atts_img[:, :1]
+        atts_bos = input_atts[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
-        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
-        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+        inputs_embeds = torch.cat([bos_embeds, input_embs, to_regress_embeds], dim=1)
+        attention_mask = torch.cat([atts_bos, input_atts, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
@@ -144,7 +144,7 @@ def compute_sim_matrix(model, data_loader, **kwargs):
    vit_feats = []
    image_embeds = []
    for samples in data_loader:
-        image = samples["image"]
+        image = samples["vision"]

        image = image.to(model.device)
        image_feat, vit_feat = model.forward_image(image)

@@ -164,7 +164,7 @@ class MiniGPT4(Blip2Base):
        return img_embeds, atts_img

    def forward(self, samples):
-        image = samples["image"]
+        image = samples["vision"]
        img_embeds, atts_img = self.encode_img(image)
        if hasattr(samples, 'question_split'):  # VQA dataset
            print('VQA Batch')
@@ -11,11 +11,15 @@ from minigpt4.processors.blip_processors import (
    Blip2ImageEvalProcessor,
    BlipCaptionProcessor,
)
-from minigpt4.processors.imagebind_processor import (
+from minigpt4.processors.imagebind_vision_processor import (
    ImageBindCaptionProcessor,
    ImageBindVisionTrainProcessor,
    ImageBindVisionEvalProcessor
)
+from minigpt4.processors.imagebind_audio_processor import (
+    ImageBindAudioTrainProcessor,
+    ImageBindAudioEvalProcessor,
+)

from minigpt4.common.registry import registry

@@ -26,7 +30,9 @@ __all__ = [
    "BlipCaptionProcessor",
    "ImageBindCaptionProcessor",
    "ImageBindVisionTrainProcessor",
-    "ImageBindVisionEvalProcessor"
+    "ImageBindVisionEvalProcessor",
+    "ImageBindAudioTrainProcessor",
+    "ImageBindAudioEvalProcessor",
]
190
minigpt4/processors/audio_augment.py
Normal file
190
minigpt4/processors/audio_augment.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @Author : Xinhao Mei @CVSSP, University of Surrey
|
||||||
|
# @E-mail : x.mei@surrey.ac.uk
|
||||||
|
|
||||||
|
"""
|
||||||
|
Implemenation of SpecAugment++,
|
||||||
|
Adapated from Qiuqiang Kong's trochlibrosa:
|
||||||
|
https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class DropStripes:
|
||||||
|
|
||||||
|
def __init__(self, dim, drop_width, stripes_num):
|
||||||
|
""" Drop stripes.
|
||||||
|
args:
|
||||||
|
dim: int, dimension along which to drop
|
||||||
|
drop_width: int, maximum width of stripes to drop
|
||||||
|
stripes_num: int, how many stripes to drop
|
||||||
|
"""
|
||||||
|
super(DropStripes, self).__init__()
|
||||||
|
|
||||||
|
assert dim in [2, 3] # dim 2: time; dim 3: frequency
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.drop_width = drop_width
|
||||||
|
self.stripes_num = stripes_num
|
||||||
|
|
||||||
|
def __call__(self, input):
|
||||||
|
"""input: (batch_size, channels, time_steps, freq_bins)"""
|
||||||
|
|
||||||
|
assert input.ndimension() == 4
|
||||||
|
batch_size = input.shape[0]
|
||||||
|
total_width = input.shape[self.dim]
|
||||||
|
|
||||||
|
for n in range(batch_size):
|
||||||
|
self.transform_slice(input[n], total_width)
|
||||||
|
|
||||||
|
return input
|
||||||
|
|
||||||
|
def transform_slice(self, e, total_width):
|
||||||
|
""" e: (channels, time_steps, freq_bins)"""
|
||||||
|
|
||||||
|
for _ in range(self.stripes_num):
|
||||||
|
distance = torch.randint(low=0, high=self.drop_width, size=(1,))[0]
|
||||||
|
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
|
||||||
|
|
||||||
|
if self.dim == 2:
|
||||||
|
e[:, bgn: bgn + distance, :] = 0
|
||||||
|
elif self.dim == 3:
|
||||||
|
e[:, :, bgn: bgn + distance] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class MixStripes:
|
||||||
|
|
||||||
|
def __init__(self, dim, mix_width, stripes_num):
|
||||||
|
""" Mix stripes
|
||||||
|
args:
|
||||||
|
dim: int, dimension along which to mix
|
||||||
|
mix_width: int, maximum width of stripes to mix
|
||||||
|
stripes_num: int, how many stripes to mix
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(MixStripes, self).__init__()
|
||||||
|
|
||||||
|
assert dim in [2, 3]
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.mix_width = mix_width
|
||||||
|
self.stripes_num = stripes_num
|
||||||
|
|
||||||
|
def __call__(self, input):
|
||||||
|
"""input: (batch_size, channel, time_steps, freq_bins)"""
|
||||||
|
|
||||||
|
assert input.ndimension() == 4
|
||||||
|
|
||||||
|
batch_size = input.shape[0]
|
||||||
|
total_width = input.shape[self.dim]
|
||||||
|
|
||||||
|
rand_sample = input[torch.randperm(batch_size)]
|
||||||
|
for i in range(batch_size):
|
||||||
|
self.transform_slice(input[i], rand_sample[i], total_width)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def transform_slice(self, input, random_sample, total_width):
|
||||||
|
|
||||||
|
for _ in range(self.stripes_num):
|
||||||
|
distance = torch.randint(low=0, high=self.mix_width, size=(1,))[0]
|
||||||
|
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
|
||||||
|
|
||||||
|
if self.dim == 2:
|
||||||
|
input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \
|
||||||
|
0.5 * random_sample[:, bgn: bgn + distance, :]
|
||||||
|
elif self.dim == 3:
|
||||||
|
input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \
|
||||||
|
0.5 * random_sample[:, :, bgn: bgn + distance]
|
||||||
|
|
||||||
|
|
||||||
|
class CutStripes:
|
||||||
|
|
||||||
|
def __init__(self, dim, cut_width, stripes_num):
|
||||||
|
""" Cutting stripes with another randomly selected sample in mini-batch.
|
||||||
|
args:
|
||||||
|
dim: int, dimension along which to cut
|
||||||
|
cut_width: int, maximum width of stripes to cut
|
||||||
|
stripes_num: int, how many stripes to cut
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(CutStripes, self).__init__()
|
||||||
|
|
||||||
|
assert dim in [2, 3]
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.cut_width = cut_width
|
||||||
|
self.stripes_num = stripes_num
|
||||||
|
|
||||||
|
def __call__(self, input):
|
||||||
|
"""input: (batch_size, channel, time_steps, freq_bins)"""
|
||||||
|
|
||||||
|
assert input.ndimension() == 4
|
||||||
|
|
||||||
|
batch_size = input.shape[0]
|
||||||
|
total_width = input.shape[self.dim]
|
||||||
|
|
||||||
|
rand_sample = input[torch.randperm(batch_size)]
|
||||||
|
for i in range(batch_size):
|
||||||
|
self.transform_slice(input[i], rand_sample[i], total_width)
|
||||||
|
return input
|
||||||
|
|
||||||
|
def transform_slice(self, input, random_sample, total_width):
|
||||||
|
|
||||||
|
for _ in range(self.stripes_num):
|
||||||
|
distance = torch.randint(low=0, high=self.cut_width, size=(1,))[0]
|
||||||
|
bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
|
||||||
|
|
||||||
|
if self.dim == 2:
|
||||||
|
input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :]
|
||||||
|
elif self.dim == 3:
|
||||||
|
input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance]
|
||||||
|
|
||||||
|
|
||||||
|
class SpecAugmentation:

    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
                 mask_type='mixture'):
        """SpecAugment and SpecAugment++.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
        and Le, Q.V., 2019. SpecAugment: A simple data augmentation method
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.

        [ref] Wang, H., Zou, Y. and Wang, W., 2021. SpecAugment++: A hidden space
        data augmentation method for acoustic scene classification. arXiv
        preprint arXiv:2103.16858.

        Args:
            time_drop_width: int
            time_stripes_num: int
            freq_drop_width: int
            freq_stripes_num: int
            mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)
        """
        super(SpecAugmentation, self).__init__()

        if mask_type == 'zero_value':
            self.time_augmentator = DropStripes(dim=2, drop_width=time_drop_width,
                                                stripes_num=time_stripes_num)
            self.freq_augmentator = DropStripes(dim=3, drop_width=freq_drop_width,
                                                stripes_num=freq_stripes_num)
        elif mask_type == 'mixture':
            self.time_augmentator = MixStripes(dim=2, mix_width=time_drop_width,
                                               stripes_num=time_stripes_num)
            self.freq_augmentator = MixStripes(dim=3, mix_width=freq_drop_width,
                                               stripes_num=freq_stripes_num)
        elif mask_type == 'cutting':
            self.time_augmentator = CutStripes(dim=2, cut_width=time_drop_width,
                                               stripes_num=time_stripes_num)
            self.freq_augmentator = CutStripes(dim=3, cut_width=freq_drop_width,
                                               stripes_num=freq_stripes_num)
        else:
            raise NameError('No such mask type in SpecAugment++')

    def __call__(self, inputs):
        # inputs should be of size [batch_size, channel, time_steps, freq_bins]
        x = self.time_augmentator(inputs)
        x = self.freq_augmentator(x)
        return x
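For orientation, a minimal usage sketch of the augmentor above on a dummy log-mel batch; this is illustrative only and not part of the diff (the shapes mirror the 204-frame / 128-bin defaults used by the audio processor below).

import torch

# Dummy batch shaped [batch_size, channel, time_steps, freq_bins], as SpecAugmentation expects.
dummy = torch.randn(4, 1, 204, 128)

# 'mixture' blends masked stripes with another sample in the batch (MixStripes);
# 'zero_value' zeroes them out (DropStripes); 'cutting' copies the other sample's stripes (CutStripes).
augmenter = SpecAugmentation(time_drop_width=13, time_stripes_num=2,
                             freq_drop_width=8, freq_stripes_num=2,
                             mask_type='mixture')
augmented = augmenter(dummy)   # modifies `dummy` in place and returns it; shape is unchanged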
@@ -9,7 +9,7 @@ import re
 
 from minigpt4.common.registry import registry
 from minigpt4.processors.base_processor import BaseProcessor
-from minigpt4.processors.randaugment import RandomAugment
+from minigpt4.processors.vision_augment import RandomAugment
 from omegaconf import OmegaConf
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
minigpt4/processors/imagebind_audio_processor.py (new file, 166 lines)
@@ -0,0 +1,166 @@
from typing import Union, List

import torch
import torchaudio
from omegaconf import OmegaConf
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler, RandomMultiClipSampler
from torch import Tensor

from imagebind.data.data_utils import waveform2melspec, get_constant_clip_timepoints, \
    get_random_clip_timepoints
from minigpt4.processors.base_processor import BaseProcessor
from torchvision import transforms
from minigpt4.common.registry import registry
from minigpt4.processors.audio_augment import SpecAugmentation


class ImageBindAudioBaseProcessor(BaseProcessor):
    def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
                 num_mel_bins=None, target_length=None, clip_sample_method="Random"):
        super().__init__()
        # Default normalization statistics for the audio mel-spectrograms.
        self.mean = -4.268 if mean is None else mean
        self.std = 9.138 if std is None else std
        self.target_sr = 16000 if target_sr is None else target_sr
        self.num_mel_bins = num_mel_bins
        self.target_length = target_length
        self.clip_sampler = self._construct_clip_sampler(clip_duration, clips_per_video, clip_sample_method)
        self.normalize = transforms.Normalize(self.mean, self.std)

    def _construct_clip_sampler(self, clip_duration, clips_per_video, clip_sample_method):
        if clip_duration is None or clips_per_video is None:
            return None
        if clip_sample_method == "Constant":
            return ConstantClipsPerVideoSampler(
                clip_duration=clip_duration, clips_per_video=clips_per_video
            )
        elif clip_sample_method == "Random":
            return RandomMultiClipSampler(clip_duration=clip_duration, num_clips=clips_per_video)
        else:
            raise NotImplementedError

    def waveform_resample(self, waveform: Tensor, origin_sr: int) -> Tensor:
        return torchaudio.functional.resample(
            waveform, orig_freq=origin_sr, new_freq=self.target_sr
        )

    def clip_sample(self, waveform: Tensor) -> List[Tensor]:
        # Split the waveform into clips according to the configured clip sampler.
        if self.clip_sampler is None:
            return [waveform]
        elif isinstance(self.clip_sampler, ConstantClipsPerVideoSampler):
            all_clips_timepoints = get_constant_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
        elif isinstance(self.clip_sampler, RandomMultiClipSampler):
            all_clips_timepoints = get_random_clip_timepoints(self.clip_sampler, waveform.size(1) / self.target_sr)
        else:
            raise NotImplementedError
        all_clips = []
        for clip_timepoints in all_clips_timepoints:
            start_pos = int(clip_timepoints[0] * self.target_sr)
            end_pos = int(clip_timepoints[1] * self.target_sr)
            waveform_clip = waveform[:, start_pos: end_pos]
            all_clips.append(waveform_clip)
        return all_clips

    def waveform_melspec(self, waveforms: Union[List[Tensor], Tensor]) -> List[Tensor]:
        if isinstance(waveforms, Tensor):
            return waveform2melspec(waveforms, self.target_sr, self.num_mel_bins, self.target_length)
        else:
            return [waveform2melspec(waveform, self.target_sr, self.num_mel_bins, self.target_length)
                    for waveform in waveforms]


@registry.register_processor("imagebind_audio_train")
class ImageBindAudioTrainProcessor(ImageBindAudioBaseProcessor):
    def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
                 clip_sample_method="Random", num_mel_bins=None, target_length=None, time_drop_width=13,
                 time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2, mask_type='mixture'):
        super().__init__(mean=mean, std=std, target_sr=target_sr,
                         clip_duration=clip_duration, clips_per_video=clips_per_video,
                         num_mel_bins=num_mel_bins, target_length=target_length,
                         clip_sample_method=clip_sample_method)
        self.spec_augment = SpecAugmentation(time_drop_width, time_stripes_num,
                                             freq_drop_width, freq_stripes_num, mask_type)

    def __call__(self, item):
        # item: Tuple[Tensor, int] -- (waveform, original sample rate)
        waveform, origin_sr = item[0], item[1]
        waveform = self.waveform_resample(waveform, origin_sr)
        waveform_clips = self.clip_sample(waveform)
        melspec_clips = self.waveform_melspec(waveform_clips)
        normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
        all_clips = torch.stack(normed_melspecs, dim=0)
        # all_clips: [clips_per_video, channel, mel_bins, time_steps];
        # spec_augment expects [batch_size, channel, time_steps, freq_bins], so swap the last two dims.
        augmented_clips = self.spec_augment(all_clips.transpose(-2, -1)).transpose(-2, -1)
        return augmented_clips

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        target_sr = cfg.get("target_sr", 16000)
        clip_duration = cfg.get("clip_duration", 2)
        clips_per_video = cfg.get("clips_per_video", 3)
        num_mel_bins = cfg.get("num_mel_bins", 128)
        target_length = cfg.get("target_length", 204)
        time_drop_width = cfg.get("time_drop_width", 13)
        time_stripes_num = cfg.get("time_stripes_num", 2)  # 13 * 2 / 204 = 12.75% time mask
        freq_drop_width = cfg.get("freq_drop_width", 8)
        freq_stripes_num = cfg.get("freq_stripes_num", 2)  # 8 * 2 / 128 = 12.5% freq mask
        mask_type = cfg.get("mask_type", 'mixture')

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        return cls(
            mean=mean, std=std, target_sr=target_sr,
            clip_duration=clip_duration, clips_per_video=clips_per_video,
            num_mel_bins=num_mel_bins, target_length=target_length,
            time_drop_width=time_drop_width, time_stripes_num=time_stripes_num,
            freq_drop_width=freq_drop_width, freq_stripes_num=freq_stripes_num,
            mask_type=mask_type
        )


@registry.register_processor("imagebind_audio_eval")
class ImageBindAudioEvalProcessor(ImageBindAudioBaseProcessor):
    def __init__(self, mean=None, std=None, target_sr=None, clip_duration=None, clips_per_video=None,
                 clip_sample_method="Constant", num_mel_bins=None, target_length=None):
        super().__init__(mean=mean, std=std, target_sr=target_sr,
                         clip_duration=clip_duration, clips_per_video=clips_per_video,
                         num_mel_bins=num_mel_bins, target_length=target_length,
                         clip_sample_method=clip_sample_method)

    def __call__(self, item):
        # item: Tuple[Tensor, int] -- (waveform, original sample rate)
        waveform, origin_sr = item[0], item[1]
        waveform = self.waveform_resample(waveform, origin_sr)
        waveform_clips = self.clip_sample(waveform)
        melspec_clips = self.waveform_melspec(waveform_clips)
        normed_melspecs = [self.normalize(clip) for clip in melspec_clips]
        all_clips = torch.stack(normed_melspecs, dim=0)
        # all_clips: [clips_per_video, channel, mel_bins, time_steps]
        return all_clips

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        target_sr = cfg.get("target_sr", 16000)
        clip_duration = cfg.get("clip_duration", 2)
        clips_per_video = cfg.get("clips_per_video", 3)
        num_mel_bins = cfg.get("num_mel_bins", 128)
        target_length = cfg.get("target_length", 204)

        mean = cfg.get("mean", None)
        std = cfg.get("std", None)

        return cls(
            mean=mean, std=std, target_sr=target_sr,
            clip_duration=clip_duration, clips_per_video=clips_per_video,
            num_mel_bins=num_mel_bins, target_length=target_length
        )
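A minimal usage sketch of the training processor (illustrative only; "clip.wav" is a placeholder path, and the config values simply mirror the defaults in from_config):

import torchaudio
from omegaconf import OmegaConf

cfg = OmegaConf.create({"clips_per_video": 3, "num_mel_bins": 128, "target_length": 204})
processor = ImageBindAudioTrainProcessor.from_config(cfg)

waveform, sr = torchaudio.load("clip.wav")   # waveform: [channels, num_samples]
clips = processor((waveform, sr))            # -> [clips_per_video, channel, mel_bins, time_steps]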
@@ -9,7 +9,7 @@ import re
 
 from minigpt4.common.registry import registry
 from minigpt4.processors.base_processor import BaseProcessor
-from minigpt4.processors.randaugment import RandomAugment
+from minigpt4.processors.vision_augment import RandomAugment
 from omegaconf import OmegaConf
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -147,3 +147,5 @@ class ImageBindVisionEvalProcessor(ImageBindVisionBaseProcessor):
         return cls(image_size=image_size, mean=mean, std=std)
+
+
@@ -53,7 +53,6 @@ ipdb
 tensorflow-cpu
 tensorboardX
 # mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
-
 bs4==0.0.1 # Needed for text cleaning
 bson==0.5.10
 byted-dataloader==0.3.6
@@ -66,3 +65,4 @@ tensorboard==2.11.2
 tensorflow==2.11.0 # Needed for tensorboard hdfs support
 tensorflow-io==0.30.0 # Needed for tensorboard hdfs support
 tqdm==4.64.1
+pytorchvideo==0.1.5
train_configs/bindgpt4_stage1_audio_pretrain.yaml (new file, 68 lines)
@@ -0,0 +1,68 @@
model:
  arch: bind_gpt4
  model_type: pretrain_vicuna
  freeze_imagebind: True
  freeze_qformer: False

datasets:
  bbc:
    audio_processor:
      train:
        name: "imagebind_audio_train"
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 31
  audioset:
    audio_processor:
      train:
        name: "imagebind_audio_train"
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 108
  soundbible:
    audio_processor:
      train:
        name: "imagebind_audio_train"
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 2
  freesound:
    audio_processor:
      train:
        name: "imagebind_audio_train"
    text_processor:
      train:
        name: "imagebind_caption"
    sample_ratio: 262

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 1e-4
  min_lr: 8e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 4
  batch_size_train: 64
  batch_size_eval: 64
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/bindgpt4_stage1_audio_pretrain"

  amp: True
  resume_ckpt_path: null
  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
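For orientation, each audio_processor entry above resolves to the processor registered earlier in this commit. A rough sketch of that wiring, assuming the registry-based builder pattern minigpt4 uses elsewhere (the actual lookup happens inside the dataset builders):

from omegaconf import OmegaConf
from minigpt4.common.registry import registry

proc_cfg = OmegaConf.create({"name": "imagebind_audio_train", "clips_per_video": 3})
audio_processor = registry.get_processor_class(proc_cfg.name).from_config(proc_cfg)

Training would then point the repo's usual training entry point at this file via its --cfg-path argument; the exact launch command is outside this diff.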