first commit
14
LICENSE.md
Normal file
@ -0,0 +1,14 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright 2023 Deyao Zhu
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
14
LICENSE_Lavis.md
Normal file
@ -0,0 +1,14 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2022 Salesforce, Inc.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
BIN
MiniGPT_4.pdf
Normal file
145
README.md
Normal file
@ -0,0 +1,145 @@
|
||||
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
|
||||
[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), Xiang Li, and Mohamed Elhoseiny. *Equal Contribution
|
||||
|
||||
**King Abdullah University of Science and Technology**
|
||||
|
||||
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
|
||||
|
||||
|
||||
## Online Demo
|
||||
|
||||
Click the image to chat with MiniGPT-4 around your images
|
||||
[](https://minigpt-4.github.io)
|
||||
|
||||
|
||||
## Examples
|
||||
| | |
|
||||
:-------------------------:|:-------------------------:
|
||||
 | 
|
||||
 | 
|
||||
|
||||
More examples can be found in the [project page](https://minigpt-4.github.io).
|
||||
|
||||
|
||||
|
||||
## Introduction
|
||||
- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer.
|
||||
- The training of MiniGPT-4 consists of a first pretrain stage using roughly 5 million aligned image-text pairs for 10 hours on 4 A100s and a second finetuning stage using additional 3,500 carefully curated high-quality pairs for 7 minutes on 1 A100.
|
||||
- MiniGPT-4 processes many emerging vision-language capabilities similar to those exhibited by GPT-4.
|
||||

|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## Getting Started
|
||||
### Installation
|
||||
|
||||
**1. Prepare the code and the environment**
|
||||
|
||||
Git clone our repository, creating a python environment and ativate it via the following command
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Vision-CAIR/MiniGPT-4.git
|
||||
cd MiniGPT-4
|
||||
conda env create -f environment.yml
|
||||
conda activate minigpt4
|
||||
```
|
||||
|
||||
|
||||
**2. Prepare the pretrained Vicuna weights**
|
||||
|
||||
The current version of MiniGPT-4 is built on the v0 versoin of Vicuna-13B.
|
||||
Please refer to their instructions [here](https://huggingface.co/lmsys/vicuna-13b-delta-v0) to obtaining the weights.
|
||||
The final weights would be in a single folder with the following structure:
|
||||
|
||||
```
|
||||
vicuna_weights
|
||||
├── config.json
|
||||
├── generation_config.json
|
||||
├── pytorch_model.bin.index.json
|
||||
├── pytorch_model-00001-of-00003.bin
|
||||
...
|
||||
```
|
||||
|
||||
Then, set the path to the vicuna weight in the model config file
|
||||
[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
|
||||
|
||||
**3. Prepare the pretrained MiniGPT-4 checkpoint**
|
||||
|
||||
To play with our pretrained model, download the pretrained checkpoint
|
||||
[here](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link).
|
||||
Then, set the path to the pretrained checkpoint in the evaluation config file
|
||||
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 10.
|
||||
|
||||
|
||||
|
||||
### Launching Demo Locally
|
||||
|
||||
Try out our demo [demo.py](demo.py) on your local machine by running
|
||||
|
||||
```
|
||||
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml
|
||||
```
|
||||
|
||||
|
||||
|
||||
### Training
|
||||
The training of MiniGPT-4 contains two alignment stages.
|
||||
|
||||
**1. First pretraining stage**
|
||||
|
||||
In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets
|
||||
to align the vision and language model. To download and prepare the datasets, please check
|
||||
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
|
||||
After the first stage, the visual features are mapped and can be understood by the language
|
||||
model.
|
||||
To launch the first stage training, run the following command. In our experiments, we use 4 A100.
|
||||
You can change the save path in the config file
|
||||
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
|
||||
```
|
||||
|
||||
**1. Second finetuning stage**
|
||||
|
||||
In the second stage, we use a small high quality image-text pair dataset created by ourselves
|
||||
and convert it to a conversation format to further align MiniGPT-4.
|
||||
To download and prepare our second stage dataset, please check our
|
||||
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
|
||||
To launch the second stage alignment,
|
||||
first specify the path to the checkpoint file trained in stage 1 in
|
||||
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml).
|
||||
You can also specify the output path there.
|
||||
Then, run the following command. In our experiments, we use 1 A100.
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
|
||||
```
|
||||
|
||||
After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly.
|
||||
|
||||
|
||||
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2)
|
||||
+ [Vicuna](https://github.com/lm-sys/FastChat)
|
||||
|
||||
|
||||
If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX:
|
||||
```bibtex
|
||||
@misc{zhu2022minigpt4,
|
||||
title={MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models},
|
||||
author={Deyao Zhu and Jun Chen and Xiaoqian Shen and xiang Li and Mohamed Elhoseiny},
|
||||
year={2023},
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
This repository is under [BSD 3-Clause License](LICENSE.md).
|
||||
Many codes are based on [Lavis](https://github.com/salesforce/LAVIS) with
|
||||
BSD 3-Clause License [here](LICENSE_Lavis.md).
|
96
dataset/README_1_STAGE.md
Normal file
@ -0,0 +1,96 @@
|
||||
## Download the filtered Conceptual Captions, SBU, LAION datasets
|
||||
|
||||
### Pre-training datasets download:
|
||||
We use the filtered synthetic captions prepared by BLIP. For more details about the dataset, please refer to [BLIP](https://github.com/salesforce/BLIP).
|
||||
|
||||
It requires ~2.3T to store LAION and CC3M+CC12M+SBU datasets
|
||||
|
||||
Image source | Filtered synthetic caption by ViT-L
|
||||
--- | :---:
|
||||
CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
|
||||
LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
|
||||
|
||||
This will download two json files
|
||||
```
|
||||
ccs_synthetic_filtered_large.json
|
||||
laion_synthetic_filtered_large.json
|
||||
```
|
||||
|
||||
## prepare the data step-by-step
|
||||
|
||||
|
||||
### setup the dataset folder and move the annotation file to the data storage folder
|
||||
```
|
||||
export MINIGPT4_DATASET=/YOUR/PATH/FOR/LARGE/DATASET/
|
||||
mkdir ${MINIGPT4_DATASET}/cc_sbu
|
||||
mkdir ${MINIGPT4_DATASET}/laion
|
||||
mv ccs_synthetic_filtered_large.json ${MINIGPT4_DATASET}/cc_sbu
|
||||
mv laion_synthetic_filtered_large.json ${MINIGPT4_DATASET}/laion
|
||||
```
|
||||
|
||||
### Convert the scripts to data storate folder
|
||||
```
|
||||
cp convert_cc_sbu.py ${MINIGPT4_DATASET}/cc_sbu
|
||||
cp download_cc_sbu.sh ${MINIGPT4_DATASET}/cc_sbu
|
||||
cp convert_laion.py ${MINIGPT4_DATASET}/laion
|
||||
cp download_laion.sh ${MINIGPT4_DATASET}/laion
|
||||
```
|
||||
|
||||
|
||||
### Convert the laion and cc_sbu annotation file format to be img2dataset format
|
||||
```
|
||||
cd ${MINIGPT4_DATASET}/cc_sbu
|
||||
python convert_cc_sbu.py
|
||||
|
||||
cd ${MINIGPT4_DATASET}/laion
|
||||
python convert_laion.py
|
||||
```
|
||||
|
||||
### Download the datasets with img2dataset
|
||||
```
|
||||
cd ${MINIGPT4_DATASET}/cc_sbu
|
||||
sh download_cc_sbu.sh
|
||||
cd ${MINIGPT4_DATASET}/laion
|
||||
sh download_laion.sh
|
||||
```
|
||||
|
||||
|
||||
The final dataset structure
|
||||
|
||||
```
|
||||
.
|
||||
├── ${MINIGPT4_DATASET}
|
||||
│ ├── cc_sbu
|
||||
│ ├── convert_cc_sbu.py
|
||||
│ ├── download_cc_sbu.sh
|
||||
│ ├── ccs_synthetic_filtered_large.json
|
||||
│ ├── ccs_synthetic_filtered_large.tsv
|
||||
│ └── cc_sbu_dataset
|
||||
│ ├── 00000.tar
|
||||
│ ├── 00000.parquet
|
||||
│ ...
|
||||
│ ├── laion
|
||||
│ ├── convert_laion.py
|
||||
│ ├── download_laion.sh
|
||||
│ ├── laion_synthetic_filtered_large.json
|
||||
│ ├── laion_synthetic_filtered_large.tsv
|
||||
│ └── laion_dataset
|
||||
│ ├── 00000.tar
|
||||
│ ├── 00000.parquet
|
||||
│ ...
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
## Set up the dataset configuration files
|
||||
|
||||
Then, set up the LAION dataset loading path in
|
||||
[here](../minigpt4/configs/datasets/laion/defaults.yaml#L5) at Line 5 as
|
||||
${MINIGPT4_DATASET}/laion/laion_dataset/{00000..10488}.tar
|
||||
|
||||
and the Conceptual Captoin and SBU datasets loading path in
|
||||
[here](../minigpt4/configs/datasets/cc_sbu/defaults.yaml#L5) at Line 5 as
|
||||
${MINIGPT4_DATASET}/cc_sbu/cc_sbu_dataset/{00000..01255}.tar
|
||||
|
||||
|
||||
|
19
dataset/README_2_STAGE.md
Normal file
@ -0,0 +1,19 @@
|
||||
## Second Stage Data Preparation
|
||||
|
||||
Our second stage dataset can be downloaded from
|
||||
[here](https://drive.google.com/file/d/1nJXhoEcy3KTExr17I7BXqY5Y9Lx_-n-9/view?usp=share_link)
|
||||
After extraction, you will get a data follder with the following structure:
|
||||
|
||||
```
|
||||
cc_sbu_align
|
||||
├── filter_cap.json
|
||||
└── image
|
||||
├── 2.jpg
|
||||
├── 3.jpg
|
||||
...
|
||||
```
|
||||
|
||||
Put the folder to any path you want.
|
||||
Then, set up the dataset path in the dataset config file
|
||||
[here](../minigpt4/configs/datasets/cc_sbu/align.yaml#L5) at Line 5.
|
||||
|
20
dataset/convert_cc_sbu.py
Normal file
@ -0,0 +1,20 @@
|
||||
import json
|
||||
import csv
|
||||
|
||||
# specify input and output file paths
|
||||
input_file = 'ccs_synthetic_filtered_large.json'
|
||||
output_file = 'ccs_synthetic_filtered_large.tsv'
|
||||
|
||||
# load JSON data from input file
|
||||
with open(input_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# extract header and data from JSON
|
||||
header = data[0].keys()
|
||||
rows = [x.values() for x in data]
|
||||
|
||||
# write data to TSV file
|
||||
with open(output_file, 'w') as f:
|
||||
writer = csv.writer(f, delimiter='\t')
|
||||
writer.writerow(header)
|
||||
writer.writerows(rows)
|
20
dataset/convert_laion.py
Normal file
@ -0,0 +1,20 @@
|
||||
import json
|
||||
import csv
|
||||
|
||||
# specify input and output file paths
|
||||
input_file = 'laion_synthetic_filtered_large.json'
|
||||
output_file = 'laion_synthetic_filtered_large.tsv'
|
||||
|
||||
# load JSON data from input file
|
||||
with open(input_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# extract header and data from JSON
|
||||
header = data[0].keys()
|
||||
rows = [x.values() for x in data]
|
||||
|
||||
# write data to TSV file
|
||||
with open(output_file, 'w') as f:
|
||||
writer = csv.writer(f, delimiter='\t')
|
||||
writer.writerow(header)
|
||||
writer.writerows(rows)
|
6
dataset/download_cc_sbu.sh
Normal file
@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
img2dataset --url_list ccs_synthetic_filtered_large.tsv --input_format "tsv"\
|
||||
--url_col "url" --caption_col "caption" --output_format webdataset\
|
||||
--output_folder cc_sbu_dataset --processes_count 16 --thread_count 128 --image_size 256 \
|
||||
--enable_wandb True
|
6
dataset/download_laion.sh
Normal file
@ -0,0 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
img2dataset --url_list laion_synthetic_filtered_large.tsv --input_format "tsv"\
|
||||
--url_col "url" --caption_col "caption" --output_format webdataset\
|
||||
--output_folder laion_dataset --processes_count 16 --thread_count 128 --image_size 256 \
|
||||
--enable_wandb True
|
145
demo.py
Normal file
@ -0,0 +1,145 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import gradio as gr
|
||||
|
||||
from minigpt4.common.config import Config
|
||||
from minigpt4.common.dist_utils import get_rank
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.conversation.conversation import Chat, CONV_VISION
|
||||
|
||||
# imports modules for registration
|
||||
from minigpt4.datasets.builders import *
|
||||
from minigpt4.models import *
|
||||
from minigpt4.processors import *
|
||||
from minigpt4.runners import *
|
||||
from minigpt4.tasks import *
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Demo")
|
||||
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
|
||||
parser.add_argument(
|
||||
"--options",
|
||||
nargs="+",
|
||||
help="override some settings in the used config, the key-value pair "
|
||||
"in xxx=yyy format will be merged into config file (deprecate), "
|
||||
"change to --cfg-options instead.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def setup_seeds(config):
|
||||
seed = config.run_cfg.seed + get_rank()
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
cudnn.benchmark = False
|
||||
cudnn.deterministic = True
|
||||
|
||||
|
||||
# ========================================
|
||||
# Model Initialization
|
||||
# ========================================
|
||||
|
||||
print('Initializing Chat')
|
||||
cfg = Config(parse_args())
|
||||
|
||||
model_config = cfg.model_cfg
|
||||
model_cls = registry.get_model_class(model_config.arch)
|
||||
model = model_cls.from_config(model_config).to('cuda:0')
|
||||
|
||||
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
|
||||
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
||||
chat = Chat(model, vis_processor)
|
||||
print('Initialization Finished')
|
||||
|
||||
# ========================================
|
||||
# Gradio Setting
|
||||
# ========================================
|
||||
|
||||
def gradio_reset(chat_state, img_list):
|
||||
chat_state.messages = []
|
||||
img_list = []
|
||||
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
|
||||
|
||||
def upload_img(gr_img, text_input, chat_state):
|
||||
if gr_img is None:
|
||||
return None, None, gr.update(interactive=True)
|
||||
chat_state = CONV_VISION.copy()
|
||||
img_list = []
|
||||
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
||||
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
|
||||
|
||||
def gradio_ask(user_message, chatbot, chat_state):
|
||||
if len(user_message) == 0:
|
||||
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
|
||||
chat.ask(user_message, chat_state)
|
||||
chatbot = chatbot + [[user_message, None]]
|
||||
return '', chatbot, chat_state
|
||||
|
||||
|
||||
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
|
||||
llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)[0]
|
||||
chatbot[-1][1] = llm_message
|
||||
return chatbot, chat_state, img_list
|
||||
|
||||
title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
|
||||
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
|
||||
article = """<strong>Paper</strong>: <a href='https://github.com/Vision-CAIR/MiniGPT-4/blob/main/MiniGPT_4.pdf' target='_blank'>Here</a>
|
||||
<strong>Code</strong>: <a href='https://github.com/Vision-CAIR/MiniGPT-4' target='_blank'>Here</a>
|
||||
<strong>Project Page</strong>: <a href='https://minigpt-4.github.io/' target='_blank'>Here</a>
|
||||
"""
|
||||
|
||||
#TODO show examples below
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(title)
|
||||
gr.Markdown(description)
|
||||
gr.Markdown(article)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=0.5):
|
||||
image = gr.Image(type="pil")
|
||||
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
|
||||
clear = gr.Button("Restart")
|
||||
|
||||
num_beams = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=16,
|
||||
value=5,
|
||||
step=1,
|
||||
interactive=True,
|
||||
label="beam search numbers)",
|
||||
)
|
||||
|
||||
temperature = gr.Slider(
|
||||
minimum=0.1,
|
||||
maximum=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
interactive=True,
|
||||
label="Temperature",
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
chat_state = gr.State()
|
||||
img_list = gr.State()
|
||||
chatbot = gr.Chatbot(label='MiniGPT-4')
|
||||
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
|
||||
|
||||
upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
|
||||
|
||||
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
|
||||
gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
|
||||
)
|
||||
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
|
||||
|
||||
demo.launch(share=True, enable_queue=True)
|
63
environment.yml
Normal file
@ -0,0 +1,63 @@
|
||||
name: minigpt4
|
||||
channels:
|
||||
- pytorch
|
||||
- defaults
|
||||
- anaconda
|
||||
dependencies:
|
||||
- python=3.9
|
||||
- cudatoolkit
|
||||
- pip
|
||||
- pytorch=1.12.1
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- torchaudio=0.12.1
|
||||
- torchvision=0.13.1
|
||||
- pip:
|
||||
- accelerate==0.16.0
|
||||
- aiohttp==3.8.4
|
||||
- aiosignal==1.3.1
|
||||
- async-timeout==4.0.2
|
||||
- attrs==22.2.0
|
||||
- bitsandbytes==0.37.0
|
||||
- cchardet==2.1.7
|
||||
- chardet==5.1.0
|
||||
- contourpy==1.0.7
|
||||
- cycler==0.11.0
|
||||
- filelock==3.9.0
|
||||
- fonttools==4.38.0
|
||||
- frozenlist==1.3.3
|
||||
- huggingface-hub==0.12.1
|
||||
- importlib-resources==5.12.0
|
||||
- kiwisolver==1.4.4
|
||||
- matplotlib==3.7.0
|
||||
- multidict==6.0.4
|
||||
- openai==0.27.0
|
||||
- packaging==23.0
|
||||
- psutil==5.9.4
|
||||
- pycocotools==2.0.6
|
||||
- pyparsing==3.0.9
|
||||
- python-dateutil==2.8.2
|
||||
- pyyaml==6.0
|
||||
- regex==2022.10.31
|
||||
- tokenizers==0.13.2
|
||||
- tqdm==4.64.1
|
||||
- transformers==4.28.0
|
||||
- timm==0.6.13
|
||||
- spacy==3.5.1
|
||||
- webdataset==0.2.48
|
||||
- scikit-learn==1.2.2
|
||||
- scipy==1.10.1
|
||||
- yarl==1.8.2
|
||||
- zipp==3.14.0
|
||||
- omegaconf==2.3.0
|
||||
- opencv-python==4.7.0.72
|
||||
- iopath==0.1.10
|
||||
- decord==0.6.0
|
||||
- tenacity==8.2.2
|
||||
- peft
|
||||
- pycocoevalcap
|
||||
- sentence-transformers
|
||||
- umap-learn
|
||||
- notebook
|
||||
- gradio==3.24.1
|
||||
- gradio-client==0.0.8
|
||||
- wandb
|
24
eval_configs/minigpt4_eval.yaml
Normal file
@ -0,0 +1,24 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
model_type: pretrain_vicuna
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
max_txt_len: 160
|
||||
end_sym: "###"
|
||||
prompt_path: "prompts/alignment.txt"
|
||||
prompt_template: '###Human: {} ###Assistant: '
|
||||
ckpt: '/path/to/pretrained/ckpt/'
|
||||
|
||||
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
|
||||
run:
|
||||
task: image_text_pretrain
|
BIN
examples/ad_1.png
Normal file
After Width: | Height: | Size: 380 KiB |
BIN
examples/ad_2.png
Normal file
After Width: | Height: | Size: 457 KiB |
BIN
examples/cook_1.png
Normal file
After Width: | Height: | Size: 538 KiB |
BIN
examples/cook_2.png
Normal file
After Width: | Height: | Size: 586 KiB |
BIN
examples/describe_1.png
Normal file
After Width: | Height: | Size: 679 KiB |
BIN
examples/describe_2.png
Normal file
After Width: | Height: | Size: 555 KiB |
BIN
examples/fact_1.png
Normal file
After Width: | Height: | Size: 468 KiB |
BIN
examples/fact_2.png
Normal file
After Width: | Height: | Size: 658 KiB |
BIN
examples/fix_1.png
Normal file
After Width: | Height: | Size: 690 KiB |
BIN
examples/fix_2.png
Normal file
After Width: | Height: | Size: 586 KiB |
BIN
examples/fun_1.png
Normal file
After Width: | Height: | Size: 713 KiB |
BIN
examples/fun_2.png
Normal file
After Width: | Height: | Size: 597 KiB |
BIN
examples/logo_1.png
Normal file
After Width: | Height: | Size: 190 KiB |
BIN
examples/op_1.png
Normal file
After Width: | Height: | Size: 603 KiB |
BIN
examples/op_2.png
Normal file
After Width: | Height: | Size: 634 KiB |
BIN
examples/people_1.png
Normal file
After Width: | Height: | Size: 249 KiB |
BIN
examples/people_2.png
Normal file
After Width: | Height: | Size: 305 KiB |
BIN
examples/rhyme_1.png
Normal file
After Width: | Height: | Size: 588 KiB |
BIN
examples/rhyme_2.png
Normal file
After Width: | Height: | Size: 805 KiB |
BIN
examples/story_1.png
Normal file
After Width: | Height: | Size: 853 KiB |
BIN
examples/story_2.png
Normal file
After Width: | Height: | Size: 567 KiB |
BIN
examples/web_1.png
Normal file
After Width: | Height: | Size: 712 KiB |
BIN
examples/wop_1.png
Normal file
After Width: | Height: | Size: 519 KiB |
BIN
examples/wop_2.png
Normal file
After Width: | Height: | Size: 565 KiB |
BIN
figs/examples/ad_1.png
Normal file
After Width: | Height: | Size: 380 KiB |
BIN
figs/examples/ad_2.png
Normal file
After Width: | Height: | Size: 457 KiB |
BIN
figs/examples/cook_1.png
Normal file
After Width: | Height: | Size: 538 KiB |
BIN
figs/examples/cook_2.png
Normal file
After Width: | Height: | Size: 586 KiB |
BIN
figs/examples/describe_1.png
Normal file
After Width: | Height: | Size: 679 KiB |
BIN
figs/examples/describe_2.png
Normal file
After Width: | Height: | Size: 555 KiB |
BIN
figs/examples/fact_1.png
Normal file
After Width: | Height: | Size: 468 KiB |
BIN
figs/examples/fact_2.png
Normal file
After Width: | Height: | Size: 658 KiB |
BIN
figs/examples/fix_1.png
Normal file
After Width: | Height: | Size: 690 KiB |
BIN
figs/examples/fix_2.png
Normal file
After Width: | Height: | Size: 586 KiB |
BIN
figs/examples/fun_1.png
Normal file
After Width: | Height: | Size: 713 KiB |
BIN
figs/examples/fun_2.png
Normal file
After Width: | Height: | Size: 597 KiB |
BIN
figs/examples/logo_1.png
Normal file
After Width: | Height: | Size: 190 KiB |
BIN
figs/examples/op_1.png
Normal file
After Width: | Height: | Size: 603 KiB |
BIN
figs/examples/op_2.png
Normal file
After Width: | Height: | Size: 634 KiB |
BIN
figs/examples/people_1.png
Normal file
After Width: | Height: | Size: 249 KiB |
BIN
figs/examples/people_2.png
Normal file
After Width: | Height: | Size: 305 KiB |
BIN
figs/examples/rhyme_1.png
Normal file
After Width: | Height: | Size: 588 KiB |
BIN
figs/examples/rhyme_2.png
Normal file
After Width: | Height: | Size: 805 KiB |
BIN
figs/examples/story_1.png
Normal file
After Width: | Height: | Size: 853 KiB |
BIN
figs/examples/story_2.png
Normal file
After Width: | Height: | Size: 567 KiB |
BIN
figs/examples/web_1.png
Normal file
After Width: | Height: | Size: 712 KiB |
BIN
figs/examples/wop_1.png
Normal file
After Width: | Height: | Size: 519 KiB |
BIN
figs/examples/wop_2.png
Normal file
After Width: | Height: | Size: 565 KiB |
BIN
figs/online_demo.png
Normal file
After Width: | Height: | Size: 1.2 MiB |
BIN
figs/overview.png
Normal file
After Width: | Height: | Size: 2.4 MiB |
31
minigpt4/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
from minigpt4.datasets.builders import *
|
||||
from minigpt4.models import *
|
||||
from minigpt4.processors import *
|
||||
from minigpt4.tasks import *
|
||||
|
||||
|
||||
root_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
|
||||
|
||||
registry.register_path("library_root", root_dir)
|
||||
repo_root = os.path.join(root_dir, "..")
|
||||
registry.register_path("repo_root", repo_root)
|
||||
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
|
||||
registry.register_path("cache_root", cache_root)
|
||||
|
||||
registry.register("MAX_INT", sys.maxsize)
|
||||
registry.register("SPLIT_NAMES", ["train", "val", "test"])
|
0
minigpt4/common/__init__.py
Normal file
468
minigpt4/common/config.py
Normal file
@ -0,0 +1,468 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, args):
|
||||
self.config = {}
|
||||
|
||||
self.args = args
|
||||
|
||||
# Register the config and configuration for setup
|
||||
registry.register("configuration", self)
|
||||
|
||||
user_config = self._build_opt_list(self.args.options)
|
||||
|
||||
config = OmegaConf.load(self.args.cfg_path)
|
||||
|
||||
runner_config = self.build_runner_config(config)
|
||||
model_config = self.build_model_config(config, **user_config)
|
||||
dataset_config = self.build_dataset_config(config)
|
||||
|
||||
# Validate the user-provided runner configuration
|
||||
# model and dataset configuration are supposed to be validated by the respective classes
|
||||
# [TODO] validate the model/dataset configuration
|
||||
# self._validate_runner_config(runner_config)
|
||||
|
||||
# Override the default configuration with user options.
|
||||
self.config = OmegaConf.merge(
|
||||
runner_config, model_config, dataset_config, user_config
|
||||
)
|
||||
|
||||
def _validate_runner_config(self, runner_config):
|
||||
"""
|
||||
This method validates the configuration, such that
|
||||
1) all the user specified options are valid;
|
||||
2) no type mismatches between the user specified options and the config.
|
||||
"""
|
||||
runner_config_validator = create_runner_config_validator()
|
||||
runner_config_validator.validate(runner_config)
|
||||
|
||||
def _build_opt_list(self, opts):
|
||||
opts_dot_list = self._convert_to_dot_list(opts)
|
||||
return OmegaConf.from_dotlist(opts_dot_list)
|
||||
|
||||
@staticmethod
|
||||
def build_model_config(config, **kwargs):
|
||||
model = config.get("model", None)
|
||||
assert model is not None, "Missing model configuration file."
|
||||
|
||||
model_cls = registry.get_model_class(model.arch)
|
||||
assert model_cls is not None, f"Model '{model.arch}' has not been registered."
|
||||
|
||||
model_type = kwargs.get("model.model_type", None)
|
||||
if not model_type:
|
||||
model_type = model.get("model_type", None)
|
||||
# else use the model type selected by user.
|
||||
|
||||
assert model_type is not None, "Missing model_type."
|
||||
|
||||
model_config_path = model_cls.default_config_path(model_type=model_type)
|
||||
|
||||
model_config = OmegaConf.create()
|
||||
# hiararchy override, customized config > default config
|
||||
model_config = OmegaConf.merge(
|
||||
model_config,
|
||||
OmegaConf.load(model_config_path),
|
||||
{"model": config["model"]},
|
||||
)
|
||||
|
||||
return model_config
|
||||
|
||||
@staticmethod
|
||||
def build_runner_config(config):
|
||||
return {"run": config.run}
|
||||
|
||||
@staticmethod
|
||||
def build_dataset_config(config):
|
||||
datasets = config.get("datasets", None)
|
||||
if datasets is None:
|
||||
raise KeyError(
|
||||
"Expecting 'datasets' as the root key for dataset configuration."
|
||||
)
|
||||
|
||||
dataset_config = OmegaConf.create()
|
||||
|
||||
for dataset_name in datasets:
|
||||
builder_cls = registry.get_builder_class(dataset_name)
|
||||
|
||||
dataset_config_type = datasets[dataset_name].get("type", "default")
|
||||
dataset_config_path = builder_cls.default_config_path(
|
||||
type=dataset_config_type
|
||||
)
|
||||
|
||||
# hiararchy override, customized config > default config
|
||||
dataset_config = OmegaConf.merge(
|
||||
dataset_config,
|
||||
OmegaConf.load(dataset_config_path),
|
||||
{"datasets": {dataset_name: config["datasets"][dataset_name]}},
|
||||
)
|
||||
|
||||
return dataset_config
|
||||
|
||||
def _convert_to_dot_list(self, opts):
|
||||
if opts is None:
|
||||
opts = []
|
||||
|
||||
if len(opts) == 0:
|
||||
return opts
|
||||
|
||||
has_equal = opts[0].find("=") != -1
|
||||
|
||||
if has_equal:
|
||||
return opts
|
||||
|
||||
return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
|
||||
|
||||
def get_config(self):
|
||||
return self.config
|
||||
|
||||
@property
|
||||
def run_cfg(self):
|
||||
return self.config.run
|
||||
|
||||
@property
|
||||
def datasets_cfg(self):
|
||||
return self.config.datasets
|
||||
|
||||
@property
|
||||
def model_cfg(self):
|
||||
return self.config.model
|
||||
|
||||
def pretty_print(self):
|
||||
logging.info("\n===== Running Parameters =====")
|
||||
logging.info(self._convert_node_to_json(self.config.run))
|
||||
|
||||
logging.info("\n====== Dataset Attributes ======")
|
||||
datasets = self.config.datasets
|
||||
|
||||
for dataset in datasets:
|
||||
if dataset in self.config.datasets:
|
||||
logging.info(f"\n======== {dataset} =======")
|
||||
dataset_config = self.config.datasets[dataset]
|
||||
logging.info(self._convert_node_to_json(dataset_config))
|
||||
else:
|
||||
logging.warning(f"No dataset named '{dataset}' in config. Skipping")
|
||||
|
||||
logging.info(f"\n====== Model Attributes ======")
|
||||
logging.info(self._convert_node_to_json(self.config.model))
|
||||
|
||||
def _convert_node_to_json(self, node):
|
||||
container = OmegaConf.to_container(node, resolve=True)
|
||||
return json.dumps(container, indent=4, sort_keys=True)
|
||||
|
||||
def to_dict(self):
|
||||
return OmegaConf.to_container(self.config)
|
||||
|
||||
|
||||
def node_to_dict(node):
|
||||
return OmegaConf.to_container(node)
|
||||
|
||||
|
||||
class ConfigValidator:
|
||||
"""
|
||||
This is a preliminary implementation to centralize and validate the configuration.
|
||||
May be altered in the future.
|
||||
|
||||
A helper class to validate configurations from yaml file.
|
||||
|
||||
This serves the following purposes:
|
||||
1. Ensure all the options in the yaml are defined, raise error if not.
|
||||
2. when type mismatches are found, the validator will raise an error.
|
||||
3. a central place to store and display helpful messages for supported configurations.
|
||||
|
||||
"""
|
||||
|
||||
class _Argument:
|
||||
def __init__(self, name, choices=None, type=None, help=None):
|
||||
self.name = name
|
||||
self.val = None
|
||||
self.choices = choices
|
||||
self.type = type
|
||||
self.help = help
|
||||
|
||||
def __str__(self):
|
||||
s = f"{self.name}={self.val}"
|
||||
if self.type is not None:
|
||||
s += f", ({self.type})"
|
||||
if self.choices is not None:
|
||||
s += f", choices: {self.choices}"
|
||||
if self.help is not None:
|
||||
s += f", ({self.help})"
|
||||
return s
|
||||
|
||||
def __init__(self, description):
|
||||
self.description = description
|
||||
|
||||
self.arguments = dict()
|
||||
|
||||
self.parsed_args = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
assert self.parsed_args is not None, "No arguments parsed yet."
|
||||
|
||||
return self.parsed_args[key]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.format_help()
|
||||
|
||||
def add_argument(self, *args, **kwargs):
|
||||
"""
|
||||
Assume the first argument is the name of the argument.
|
||||
"""
|
||||
self.arguments[args[0]] = self._Argument(*args, **kwargs)
|
||||
|
||||
def validate(self, config=None):
|
||||
"""
|
||||
Convert yaml config (dict-like) to list, required by argparse.
|
||||
"""
|
||||
for k, v in config.items():
|
||||
assert (
|
||||
k in self.arguments
|
||||
), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
|
||||
|
||||
if self.arguments[k].type is not None:
|
||||
try:
|
||||
self.arguments[k].val = self.arguments[k].type(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
|
||||
|
||||
if self.arguments[k].choices is not None:
|
||||
assert (
|
||||
v in self.arguments[k].choices
|
||||
), f"""{k} must be one of {self.arguments[k].choices}."""
|
||||
|
||||
return config
|
||||
|
||||
def format_arguments(self):
|
||||
return str([f"{k}" for k in sorted(self.arguments.keys())])
|
||||
|
||||
def format_help(self):
|
||||
# description + key-value pair string for each argument
|
||||
help_msg = str(self.description)
|
||||
return help_msg + ", available arguments: " + self.format_arguments()
|
||||
|
||||
def print_help(self):
|
||||
# display help message
|
||||
print(self.format_help())
|
||||
|
||||
|
||||
def create_runner_config_validator():
|
||||
validator = ConfigValidator(description="Runner configurations")
|
||||
|
||||
validator.add_argument(
|
||||
"runner",
|
||||
type=str,
|
||||
choices=["runner_base", "runner_iter"],
|
||||
help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
|
||||
runner runs based on iters. Default: runner_base""",
|
||||
)
|
||||
# add argumetns for training dataset ratios
|
||||
validator.add_argument(
|
||||
"train_dataset_ratios",
|
||||
type=Dict[str, float],
|
||||
help="""Ratios of training dataset. This is used in iteration-based runner.
|
||||
Do not support for epoch-based runner because how to define an epoch becomes tricky.
|
||||
Default: None""",
|
||||
)
|
||||
validator.add_argument(
|
||||
"max_iters",
|
||||
type=float,
|
||||
help="Maximum number of iterations to run.",
|
||||
)
|
||||
validator.add_argument(
|
||||
"max_epoch",
|
||||
type=int,
|
||||
help="Maximum number of epochs to run.",
|
||||
)
|
||||
# add arguments for iters_per_inner_epoch
|
||||
validator.add_argument(
|
||||
"iters_per_inner_epoch",
|
||||
type=float,
|
||||
help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
|
||||
)
|
||||
lr_scheds_choices = registry.list_lr_schedulers()
|
||||
validator.add_argument(
|
||||
"lr_sched",
|
||||
type=str,
|
||||
choices=lr_scheds_choices,
|
||||
help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
|
||||
)
|
||||
task_choices = registry.list_tasks()
|
||||
validator.add_argument(
|
||||
"task",
|
||||
type=str,
|
||||
choices=task_choices,
|
||||
help="Task to use, from {}".format(task_choices),
|
||||
)
|
||||
# add arguments for init_lr
|
||||
validator.add_argument(
|
||||
"init_lr",
|
||||
type=float,
|
||||
help="Initial learning rate. This will be the learning rate after warmup and before decay.",
|
||||
)
|
||||
# add arguments for min_lr
|
||||
validator.add_argument(
|
||||
"min_lr",
|
||||
type=float,
|
||||
help="Minimum learning rate (after decay).",
|
||||
)
|
||||
# add arguments for warmup_lr
|
||||
validator.add_argument(
|
||||
"warmup_lr",
|
||||
type=float,
|
||||
help="Starting learning rate for warmup.",
|
||||
)
|
||||
# add arguments for learning rate decay rate
|
||||
validator.add_argument(
|
||||
"lr_decay_rate",
|
||||
type=float,
|
||||
help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
|
||||
)
|
||||
# add arguments for weight decay
|
||||
validator.add_argument(
|
||||
"weight_decay",
|
||||
type=float,
|
||||
help="Weight decay rate.",
|
||||
)
|
||||
# add arguments for training batch size
|
||||
validator.add_argument(
|
||||
"batch_size_train",
|
||||
type=int,
|
||||
help="Training batch size.",
|
||||
)
|
||||
# add arguments for evaluation batch size
|
||||
validator.add_argument(
|
||||
"batch_size_eval",
|
||||
type=int,
|
||||
help="Evaluation batch size, including validation and testing.",
|
||||
)
|
||||
# add arguments for number of workers for data loading
|
||||
validator.add_argument(
|
||||
"num_workers",
|
||||
help="Number of workers for data loading.",
|
||||
)
|
||||
# add arguments for warm up steps
|
||||
validator.add_argument(
|
||||
"warmup_steps",
|
||||
type=int,
|
||||
help="Number of warmup steps. Required if a warmup schedule is used.",
|
||||
)
|
||||
# add arguments for random seed
|
||||
validator.add_argument(
|
||||
"seed",
|
||||
type=int,
|
||||
help="Random seed.",
|
||||
)
|
||||
# add arguments for output directory
|
||||
validator.add_argument(
|
||||
"output_dir",
|
||||
type=str,
|
||||
help="Output directory to save checkpoints and logs.",
|
||||
)
|
||||
# add arguments for whether only use evaluation
|
||||
validator.add_argument(
|
||||
"evaluate",
|
||||
help="Whether to only evaluate the model. If true, training will not be performed.",
|
||||
)
|
||||
# add arguments for splits used for training, e.g. ["train", "val"]
|
||||
validator.add_argument(
|
||||
"train_splits",
|
||||
type=list,
|
||||
help="Splits to use for training.",
|
||||
)
|
||||
# add arguments for splits used for validation, e.g. ["val"]
|
||||
validator.add_argument(
|
||||
"valid_splits",
|
||||
type=list,
|
||||
help="Splits to use for validation. If not provided, will skip the validation.",
|
||||
)
|
||||
# add arguments for splits used for testing, e.g. ["test"]
|
||||
validator.add_argument(
|
||||
"test_splits",
|
||||
type=list,
|
||||
help="Splits to use for testing. If not provided, will skip the testing.",
|
||||
)
|
||||
# add arguments for accumulating gradient for iterations
|
||||
validator.add_argument(
|
||||
"accum_grad_iters",
|
||||
type=int,
|
||||
help="Number of iterations to accumulate gradient for.",
|
||||
)
|
||||
|
||||
# ====== distributed training ======
|
||||
validator.add_argument(
|
||||
"device",
|
||||
type=str,
|
||||
choices=["cpu", "cuda"],
|
||||
help="Device to use. Support 'cuda' or 'cpu' as for now.",
|
||||
)
|
||||
validator.add_argument(
|
||||
"world_size",
|
||||
type=int,
|
||||
help="Number of processes participating in the job.",
|
||||
)
|
||||
validator.add_argument("dist_url", type=str)
|
||||
validator.add_argument("distributed", type=bool)
|
||||
# add arguments to opt using distributed sampler during evaluation or not
|
||||
validator.add_argument(
|
||||
"use_dist_eval_sampler",
|
||||
type=bool,
|
||||
help="Whether to use distributed sampler during evaluation or not.",
|
||||
)
|
||||
|
||||
# ====== task specific ======
|
||||
# generation task specific arguments
|
||||
# add arguments for maximal length of text output
|
||||
validator.add_argument(
|
||||
"max_len",
|
||||
type=int,
|
||||
help="Maximal length of text output.",
|
||||
)
|
||||
# add arguments for minimal length of text output
|
||||
validator.add_argument(
|
||||
"min_len",
|
||||
type=int,
|
||||
help="Minimal length of text output.",
|
||||
)
|
||||
# add arguments number of beams
|
||||
validator.add_argument(
|
||||
"num_beams",
|
||||
type=int,
|
||||
help="Number of beams used for beam search.",
|
||||
)
|
||||
|
||||
# vqa task specific arguments
|
||||
# add arguments for number of answer candidates
|
||||
validator.add_argument(
|
||||
"num_ans_candidates",
|
||||
type=int,
|
||||
help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
|
||||
)
|
||||
# add arguments for inference method
|
||||
validator.add_argument(
|
||||
"inference_method",
|
||||
type=str,
|
||||
choices=["genearte", "rank"],
|
||||
help="""Inference method to use for question answering. If rank, requires a answer list.""",
|
||||
)
|
||||
|
||||
# ====== model specific ======
|
||||
validator.add_argument(
|
||||
"k_test",
|
||||
type=int,
|
||||
help="Number of top k most similar samples from ITC/VTC selection to be tested.",
|
||||
)
|
||||
|
||||
return validator
|
137
minigpt4/common/dist_utils.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import timm.models.hub as timm_hub
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
elif "SLURM_PROCID" in os.environ:
|
||||
args.rank = int(os.environ["SLURM_PROCID"])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print(
|
||||
"| distributed init (rank {}, world {}): {}".format(
|
||||
args.rank, args.world_size, args.dist_url
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend,
|
||||
init_method=args.dist_url,
|
||||
world_size=args.world_size,
|
||||
rank=args.rank,
|
||||
timeout=datetime.timedelta(
|
||||
days=365
|
||||
), # allow auto-downloading and de-compressing
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
def get_dist_info():
|
||||
if torch.__version__ < "1.0":
|
||||
initialized = dist._initialized
|
||||
else:
|
||||
initialized = dist.is_initialized()
|
||||
if initialized:
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
else: # non-distributed training
|
||||
rank = 0
|
||||
world_size = 1
|
||||
return rank, world_size
|
||||
|
||||
|
||||
def main_process(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def download_cached_file(url, check_hash=True, progress=False):
|
||||
"""
|
||||
Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
|
||||
If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
|
||||
"""
|
||||
|
||||
def get_cached_file_path():
|
||||
# a hack to sync the file path across processes
|
||||
parts = torch.hub.urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
|
||||
|
||||
return cached_file
|
||||
|
||||
if is_main_process():
|
||||
timm_hub.download_cached_file(url, check_hash, progress)
|
||||
|
||||
if is_dist_avail_and_initialized():
|
||||
dist.barrier()
|
||||
|
||||
return get_cached_file_path()
|
24
minigpt4/common/gradcam.py
Normal file
@ -0,0 +1,24 @@
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from scipy.ndimage import filters
|
||||
from skimage import transform as skimage_transform
|
||||
|
||||
|
||||
def getAttMap(img, attMap, blur=True, overlap=True):
|
||||
attMap -= attMap.min()
|
||||
if attMap.max() > 0:
|
||||
attMap /= attMap.max()
|
||||
attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
|
||||
if blur:
|
||||
attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
|
||||
attMap -= attMap.min()
|
||||
attMap /= attMap.max()
|
||||
cmap = plt.get_cmap("jet")
|
||||
attMapV = cmap(attMap)
|
||||
attMapV = np.delete(attMapV, 3, 2)
|
||||
if overlap:
|
||||
attMap = (
|
||||
1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
|
||||
+ (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
|
||||
)
|
||||
return attMap
|
195
minigpt4/common/logger.py
Normal file
@ -0,0 +1,195 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from minigpt4.common import dist_utils
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not dist_utils.is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value,
|
||||
)
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError(
|
||||
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append("{}: {}".format(name, str(meter)))
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def global_avg(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
log_msg = [
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
log_msg.append("max mem: {memory:.0f}")
|
||||
log_msg = self.delimiter.join(log_msg)
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
)
|
||||
)
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print(
|
||||
"{} Total time: {} ({:.4f} s / it)".format(
|
||||
header, total_time_str, total_time / len(iterable)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def setup_logger():
|
||||
logging.basicConfig(
|
||||
level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
119
minigpt4/common/optims.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
|
||||
@registry.register_lr_scheduler("linear_warmup_step_lr")
|
||||
class LinearWarmupStepLRScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
max_epoch,
|
||||
min_lr,
|
||||
init_lr,
|
||||
decay_rate=1,
|
||||
warmup_start_lr=-1,
|
||||
warmup_steps=0,
|
||||
**kwargs
|
||||
):
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.max_epoch = max_epoch
|
||||
self.min_lr = min_lr
|
||||
|
||||
self.decay_rate = decay_rate
|
||||
|
||||
self.init_lr = init_lr
|
||||
self.warmup_steps = warmup_steps
|
||||
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
||||
|
||||
def step(self, cur_epoch, cur_step):
|
||||
if cur_epoch == 0:
|
||||
warmup_lr_schedule(
|
||||
step=cur_step,
|
||||
optimizer=self.optimizer,
|
||||
max_step=self.warmup_steps,
|
||||
init_lr=self.warmup_start_lr,
|
||||
max_lr=self.init_lr,
|
||||
)
|
||||
else:
|
||||
step_lr_schedule(
|
||||
epoch=cur_epoch,
|
||||
optimizer=self.optimizer,
|
||||
init_lr=self.init_lr,
|
||||
min_lr=self.min_lr,
|
||||
decay_rate=self.decay_rate,
|
||||
)
|
||||
|
||||
|
||||
@registry.register_lr_scheduler("linear_warmup_cosine_lr")
|
||||
class LinearWarmupCosineLRScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
max_epoch,
|
||||
iters_per_epoch,
|
||||
min_lr,
|
||||
init_lr,
|
||||
warmup_steps=0,
|
||||
warmup_start_lr=-1,
|
||||
**kwargs
|
||||
):
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.max_epoch = max_epoch
|
||||
self.iters_per_epoch = iters_per_epoch
|
||||
self.min_lr = min_lr
|
||||
|
||||
self.init_lr = init_lr
|
||||
self.warmup_steps = warmup_steps
|
||||
self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
||||
|
||||
def step(self, cur_epoch, cur_step):
|
||||
total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
|
||||
if total_cur_step < self.warmup_steps:
|
||||
warmup_lr_schedule(
|
||||
step=cur_step,
|
||||
optimizer=self.optimizer,
|
||||
max_step=self.warmup_steps,
|
||||
init_lr=self.warmup_start_lr,
|
||||
max_lr=self.init_lr,
|
||||
)
|
||||
else:
|
||||
cosine_lr_schedule(
|
||||
epoch=total_cur_step,
|
||||
optimizer=self.optimizer,
|
||||
max_epoch=self.max_epoch * self.iters_per_epoch,
|
||||
init_lr=self.init_lr,
|
||||
min_lr=self.min_lr,
|
||||
)
|
||||
|
||||
|
||||
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
||||
"""Decay the learning rate"""
|
||||
lr = (init_lr - min_lr) * 0.5 * (
|
||||
1.0 + math.cos(math.pi * epoch / max_epoch)
|
||||
) + min_lr
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
|
||||
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
||||
"""Warmup the learning rate"""
|
||||
lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
|
||||
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
||||
"""Decay the learning rate"""
|
||||
lr = max(min_lr, init_lr * (decay_rate**epoch))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
329
minigpt4/common/registry.py
Normal file
@ -0,0 +1,329 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
|
||||
class Registry:
|
||||
mapping = {
|
||||
"builder_name_mapping": {},
|
||||
"task_name_mapping": {},
|
||||
"processor_name_mapping": {},
|
||||
"model_name_mapping": {},
|
||||
"lr_scheduler_name_mapping": {},
|
||||
"runner_name_mapping": {},
|
||||
"state": {},
|
||||
"paths": {},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_builder(cls, name):
|
||||
r"""Register a dataset builder to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the builder will be registered.
|
||||
|
||||
Usage:
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.datasets.base_dataset_builder import BaseDatasetBuilder
|
||||
"""
|
||||
|
||||
def wrap(builder_cls):
|
||||
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
||||
|
||||
assert issubclass(
|
||||
builder_cls, BaseDatasetBuilder
|
||||
), "All builders must inherit BaseDatasetBuilder class, found {}".format(
|
||||
builder_cls
|
||||
)
|
||||
if name in cls.mapping["builder_name_mapping"]:
|
||||
raise KeyError(
|
||||
"Name '{}' already registered for {}.".format(
|
||||
name, cls.mapping["builder_name_mapping"][name]
|
||||
)
|
||||
)
|
||||
cls.mapping["builder_name_mapping"][name] = builder_cls
|
||||
return builder_cls
|
||||
|
||||
return wrap
|
||||
|
||||
@classmethod
|
||||
def register_task(cls, name):
|
||||
r"""Register a task to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the task will be registered.
|
||||
|
||||
Usage:
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
"""
|
||||
|
||||
def wrap(task_cls):
|
||||
from minigpt4.tasks.base_task import BaseTask
|
||||
|
||||
assert issubclass(
|
||||
task_cls, BaseTask
|
||||
), "All tasks must inherit BaseTask class"
|
||||
if name in cls.mapping["task_name_mapping"]:
|
||||
raise KeyError(
|
||||
"Name '{}' already registered for {}.".format(
|
||||
name, cls.mapping["task_name_mapping"][name]
|
||||
)
|
||||
)
|
||||
cls.mapping["task_name_mapping"][name] = task_cls
|
||||
return task_cls
|
||||
|
||||
return wrap
|
||||
|
||||
@classmethod
|
||||
def register_model(cls, name):
|
||||
r"""Register a task to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the task will be registered.
|
||||
|
||||
Usage:
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
"""
|
||||
|
||||
def wrap(model_cls):
|
||||
from minigpt4.models import BaseModel
|
||||
|
||||
assert issubclass(
|
||||
model_cls, BaseModel
|
||||
), "All models must inherit BaseModel class"
|
||||
if name in cls.mapping["model_name_mapping"]:
|
||||
raise KeyError(
|
||||
"Name '{}' already registered for {}.".format(
|
||||
name, cls.mapping["model_name_mapping"][name]
|
||||
)
|
||||
)
|
||||
cls.mapping["model_name_mapping"][name] = model_cls
|
||||
return model_cls
|
||||
|
||||
return wrap
|
||||
|
||||
@classmethod
|
||||
def register_processor(cls, name):
|
||||
r"""Register a processor to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the task will be registered.
|
||||
|
||||
Usage:
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
"""
|
||||
|
||||
def wrap(processor_cls):
|
||||
from minigpt4.processors import BaseProcessor
|
||||
|
||||
assert issubclass(
|
||||
processor_cls, BaseProcessor
|
||||
), "All processors must inherit BaseProcessor class"
|
||||
if name in cls.mapping["processor_name_mapping"]:
|
||||
raise KeyError(
|
||||
"Name '{}' already registered for {}.".format(
|
||||
name, cls.mapping["processor_name_mapping"][name]
|
||||
)
|
||||
)
|
||||
cls.mapping["processor_name_mapping"][name] = processor_cls
|
||||
return processor_cls
|
||||
|
||||
return wrap
|
||||
|
||||
@classmethod
|
||||
def register_lr_scheduler(cls, name):
|
||||
r"""Register a model to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the task will be registered.
|
||||
|
||||
Usage:
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
"""
|
||||
|
||||
def wrap(lr_sched_cls):
|
||||
if name in cls.mapping["lr_scheduler_name_mapping"]:
|
||||
raise KeyError(
|
||||
"Name '{}' already registered for {}.".format(
|
||||
name, cls.mapping["lr_scheduler_name_mapping"][name]
|
||||
)
|
||||
)
|
||||
cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
|
||||
return lr_sched_cls
|
||||
|
||||
return wrap
|
||||
|
||||
@classmethod
|
||||
def register_runner(cls, name):
|
||||
r"""Register a model to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the task will be registered.
|
||||
|
||||
Usage:
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
"""
|
||||
|
||||
def wrap(runner_cls):
|
||||
if name in cls.mapping["runner_name_mapping"]:
|
||||
raise KeyError(
|
||||
"Name '{}' already registered for {}.".format(
|
||||
name, cls.mapping["runner_name_mapping"][name]
|
||||
)
|
||||
)
|
||||
cls.mapping["runner_name_mapping"][name] = runner_cls
|
||||
return runner_cls
|
||||
|
||||
return wrap
|
||||
|
||||
@classmethod
|
||||
def register_path(cls, name, path):
|
||||
r"""Register a path to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the path will be registered.
|
||||
|
||||
Usage:
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
"""
|
||||
assert isinstance(path, str), "All path must be str."
|
||||
if name in cls.mapping["paths"]:
|
||||
raise KeyError("Name '{}' already registered.".format(name))
|
||||
cls.mapping["paths"][name] = path
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, obj):
|
||||
r"""Register an item to registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key with which the item will be registered.
|
||||
|
||||
Usage::
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
registry.register("config", {})
|
||||
"""
|
||||
path = name.split(".")
|
||||
current = cls.mapping["state"]
|
||||
|
||||
for part in path[:-1]:
|
||||
if part not in current:
|
||||
current[part] = {}
|
||||
current = current[part]
|
||||
|
||||
current[path[-1]] = obj
|
||||
|
||||
# @classmethod
|
||||
# def get_trainer_class(cls, name):
|
||||
# return cls.mapping["trainer_name_mapping"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def get_builder_class(cls, name):
|
||||
return cls.mapping["builder_name_mapping"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def get_model_class(cls, name):
|
||||
return cls.mapping["model_name_mapping"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def get_task_class(cls, name):
|
||||
return cls.mapping["task_name_mapping"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def get_processor_class(cls, name):
|
||||
return cls.mapping["processor_name_mapping"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def get_lr_scheduler_class(cls, name):
|
||||
return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def get_runner_class(cls, name):
|
||||
return cls.mapping["runner_name_mapping"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def list_runners(cls):
|
||||
return sorted(cls.mapping["runner_name_mapping"].keys())
|
||||
|
||||
@classmethod
|
||||
def list_models(cls):
|
||||
return sorted(cls.mapping["model_name_mapping"].keys())
|
||||
|
||||
@classmethod
|
||||
def list_tasks(cls):
|
||||
return sorted(cls.mapping["task_name_mapping"].keys())
|
||||
|
||||
@classmethod
|
||||
def list_processors(cls):
|
||||
return sorted(cls.mapping["processor_name_mapping"].keys())
|
||||
|
||||
@classmethod
|
||||
def list_lr_schedulers(cls):
|
||||
return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
|
||||
|
||||
@classmethod
|
||||
def list_datasets(cls):
|
||||
return sorted(cls.mapping["builder_name_mapping"].keys())
|
||||
|
||||
@classmethod
|
||||
def get_path(cls, name):
|
||||
return cls.mapping["paths"].get(name, None)
|
||||
|
||||
@classmethod
|
||||
def get(cls, name, default=None, no_warning=False):
|
||||
r"""Get an item from registry with key 'name'
|
||||
|
||||
Args:
|
||||
name (string): Key whose value needs to be retrieved.
|
||||
default: If passed and key is not in registry, default value will
|
||||
be returned with a warning. Default: None
|
||||
no_warning (bool): If passed as True, warning when key doesn't exist
|
||||
will not be generated. Useful for MMF's
|
||||
internal operations. Default: False
|
||||
"""
|
||||
original_name = name
|
||||
name = name.split(".")
|
||||
value = cls.mapping["state"]
|
||||
for subname in name:
|
||||
value = value.get(subname, default)
|
||||
if value is default:
|
||||
break
|
||||
|
||||
if (
|
||||
"writer" in cls.mapping["state"]
|
||||
and value == default
|
||||
and no_warning is False
|
||||
):
|
||||
cls.mapping["state"]["writer"].warning(
|
||||
"Key {} is not present in registry, returning default value "
|
||||
"of {}".format(original_name, default)
|
||||
)
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def unregister(cls, name):
|
||||
r"""Remove an item from registry with key 'name'
|
||||
|
||||
Args:
|
||||
name: Key which needs to be removed.
|
||||
Usage::
|
||||
|
||||
from mmf.common.registry import registry
|
||||
|
||||
config = registry.unregister("config")
|
||||
"""
|
||||
return cls.mapping["state"].pop(name, None)
|
||||
|
||||
|
||||
registry = Registry()
|
424
minigpt4/common/utils.py
Normal file
@ -0,0 +1,424 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
import urllib
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from iopath.common.download import download
|
||||
from iopath.common.file_io import file_lock, g_pathmgr
|
||||
from minigpt4.common.registry import registry
|
||||
from torch.utils.model_zoo import tqdm
|
||||
from torchvision.datasets.utils import (
|
||||
check_integrity,
|
||||
download_file_from_google_drive,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
def now():
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().strftime("%Y%m%d%H%M")[:-1]
|
||||
|
||||
|
||||
def is_url(url_or_filename):
|
||||
parsed = urlparse(url_or_filename)
|
||||
return parsed.scheme in ("http", "https")
|
||||
|
||||
|
||||
def get_cache_path(rel_path):
|
||||
return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
|
||||
|
||||
|
||||
def get_abs_path(rel_path):
|
||||
return os.path.join(registry.get_path("library_root"), rel_path)
|
||||
|
||||
|
||||
def load_json(filename):
|
||||
with open(filename, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
# The following are adapted from torchvision and vissl
|
||||
# torchvision: https://github.com/pytorch/vision
|
||||
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
|
||||
|
||||
|
||||
def makedir(dir_path):
|
||||
"""
|
||||
Create the directory if it does not exist.
|
||||
"""
|
||||
is_success = False
|
||||
try:
|
||||
if not g_pathmgr.exists(dir_path):
|
||||
g_pathmgr.mkdirs(dir_path)
|
||||
is_success = True
|
||||
except BaseException:
|
||||
print(f"Error creating directory: {dir_path}")
|
||||
return is_success
|
||||
|
||||
|
||||
def get_redirected_url(url: str):
|
||||
"""
|
||||
Given a URL, returns the URL it redirects to or the
|
||||
original URL in case of no indirection
|
||||
"""
|
||||
import requests
|
||||
|
||||
with requests.Session() as session:
|
||||
with session.get(url, stream=True, allow_redirects=True) as response:
|
||||
if response.history:
|
||||
return response.url
|
||||
else:
|
||||
return url
|
||||
|
||||
|
||||
def to_google_drive_download_url(view_url: str) -> str:
|
||||
"""
|
||||
Utility function to transform a view URL of google drive
|
||||
to a download URL for google drive
|
||||
Example input:
|
||||
https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
|
||||
Example output:
|
||||
https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
|
||||
"""
|
||||
splits = view_url.split("/")
|
||||
assert splits[-1] == "view"
|
||||
file_id = splits[-2]
|
||||
return f"https://drive.google.com/uc?export=download&id={file_id}"
|
||||
|
||||
|
||||
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
|
||||
"""
|
||||
Download a file from google drive
|
||||
Downloading an URL from google drive requires confirmation when
|
||||
the file of the size is too big (google drive notifies that
|
||||
anti-viral checks cannot be performed on such files)
|
||||
"""
|
||||
import requests
|
||||
|
||||
with requests.Session() as session:
|
||||
|
||||
# First get the confirmation token and append it to the URL
|
||||
with session.get(url, stream=True, allow_redirects=True) as response:
|
||||
for k, v in response.cookies.items():
|
||||
if k.startswith("download_warning"):
|
||||
url = url + "&confirm=" + v
|
||||
|
||||
# Then download the content of the file
|
||||
with session.get(url, stream=True, verify=True) as response:
|
||||
makedir(output_path)
|
||||
path = os.path.join(output_path, output_file_name)
|
||||
total_size = int(response.headers.get("Content-length", 0))
|
||||
with open(path, "wb") as file:
|
||||
from tqdm import tqdm
|
||||
|
||||
with tqdm(total=total_size) as progress_bar:
|
||||
for block in response.iter_content(
|
||||
chunk_size=io.DEFAULT_BUFFER_SIZE
|
||||
):
|
||||
file.write(block)
|
||||
progress_bar.update(len(block))
|
||||
|
||||
|
||||
def _get_google_drive_file_id(url: str) -> Optional[str]:
|
||||
parts = urlparse(url)
|
||||
|
||||
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
|
||||
return None
|
||||
|
||||
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
|
||||
if match is None:
|
||||
return None
|
||||
|
||||
return match.group("id")
|
||||
|
||||
|
||||
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
|
||||
with open(filename, "wb") as fh:
|
||||
with urllib.request.urlopen(
|
||||
urllib.request.Request(url, headers={"User-Agent": "vissl"})
|
||||
) as response:
|
||||
with tqdm(total=response.length) as pbar:
|
||||
for chunk in iter(lambda: response.read(chunk_size), ""):
|
||||
if not chunk:
|
||||
break
|
||||
pbar.update(chunk_size)
|
||||
fh.write(chunk)
|
||||
|
||||
|
||||
def download_url(
|
||||
url: str,
|
||||
root: str,
|
||||
filename: Optional[str] = None,
|
||||
md5: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Download a file from a url and place it in root.
|
||||
Args:
|
||||
url (str): URL to download file from
|
||||
root (str): Directory to place downloaded file in
|
||||
filename (str, optional): Name to save the file under.
|
||||
If None, use the basename of the URL.
|
||||
md5 (str, optional): MD5 checksum of the download. If None, do not check
|
||||
"""
|
||||
root = os.path.expanduser(root)
|
||||
if not filename:
|
||||
filename = os.path.basename(url)
|
||||
fpath = os.path.join(root, filename)
|
||||
|
||||
makedir(root)
|
||||
|
||||
# check if file is already present locally
|
||||
if check_integrity(fpath, md5):
|
||||
print("Using downloaded and verified file: " + fpath)
|
||||
return
|
||||
|
||||
# expand redirect chain if needed
|
||||
url = get_redirected_url(url)
|
||||
|
||||
# check if file is located on Google Drive
|
||||
file_id = _get_google_drive_file_id(url)
|
||||
if file_id is not None:
|
||||
return download_file_from_google_drive(file_id, root, filename, md5)
|
||||
|
||||
# download the file
|
||||
try:
|
||||
print("Downloading " + url + " to " + fpath)
|
||||
_urlretrieve(url, fpath)
|
||||
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
|
||||
if url[:5] == "https":
|
||||
url = url.replace("https:", "http:")
|
||||
print(
|
||||
"Failed download. Trying https -> http instead."
|
||||
" Downloading " + url + " to " + fpath
|
||||
)
|
||||
_urlretrieve(url, fpath)
|
||||
else:
|
||||
raise e
|
||||
|
||||
# check integrity of downloaded file
|
||||
if not check_integrity(fpath, md5):
|
||||
raise RuntimeError("File not found or corrupted.")
|
||||
|
||||
|
||||
def download_and_extract_archive(
|
||||
url: str,
|
||||
download_root: str,
|
||||
extract_root: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
md5: Optional[str] = None,
|
||||
remove_finished: bool = False,
|
||||
) -> None:
|
||||
download_root = os.path.expanduser(download_root)
|
||||
if extract_root is None:
|
||||
extract_root = download_root
|
||||
if not filename:
|
||||
filename = os.path.basename(url)
|
||||
|
||||
download_url(url, download_root, filename, md5)
|
||||
|
||||
archive = os.path.join(download_root, filename)
|
||||
print("Extracting {} to {}".format(archive, extract_root))
|
||||
extract_archive(archive, extract_root, remove_finished)
|
||||
|
||||
|
||||
def cache_url(url: str, cache_dir: str) -> str:
|
||||
"""
|
||||
This implementation downloads the remote resource and caches it locally.
|
||||
The resource will only be downloaded if not previously requested.
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
|
||||
makedir(dirname)
|
||||
filename = url.split("/")[-1]
|
||||
cached = os.path.join(dirname, filename)
|
||||
with file_lock(cached):
|
||||
if not os.path.isfile(cached):
|
||||
logging.info(f"Downloading {url} to {cached} ...")
|
||||
cached = download(url, dirname, filename=filename)
|
||||
logging.info(f"URL {url} cached in {cached}")
|
||||
return cached
|
||||
|
||||
|
||||
# TODO (prigoyal): convert this into RAII-style API
|
||||
def create_file_symlink(file1, file2):
|
||||
"""
|
||||
Simply create the symlinks for a given file1 to file2.
|
||||
Useful during model checkpointing to symlinks to the
|
||||
latest successful checkpoint.
|
||||
"""
|
||||
try:
|
||||
if g_pathmgr.exists(file2):
|
||||
g_pathmgr.rm(file2)
|
||||
g_pathmgr.symlink(file1, file2)
|
||||
except Exception as e:
|
||||
logging.info(f"Could NOT create symlink. Error: {e}")
|
||||
|
||||
|
||||
def save_file(data, filename, append_to_json=True, verbose=True):
|
||||
"""
|
||||
Common i/o utility to handle saving data to various file formats.
|
||||
Supported:
|
||||
.pkl, .pickle, .npy, .json
|
||||
Specifically for .json, users have the option to either append (default)
|
||||
or rewrite by passing in Boolean value to append_to_json.
|
||||
"""
|
||||
if verbose:
|
||||
logging.info(f"Saving data to file: {filename}")
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
if file_ext in [".pkl", ".pickle"]:
|
||||
with g_pathmgr.open(filename, "wb") as fopen:
|
||||
pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
|
||||
elif file_ext == ".npy":
|
||||
with g_pathmgr.open(filename, "wb") as fopen:
|
||||
np.save(fopen, data)
|
||||
elif file_ext == ".json":
|
||||
if append_to_json:
|
||||
with g_pathmgr.open(filename, "a") as fopen:
|
||||
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
||||
fopen.flush()
|
||||
else:
|
||||
with g_pathmgr.open(filename, "w") as fopen:
|
||||
fopen.write(json.dumps(data, sort_keys=True) + "\n")
|
||||
fopen.flush()
|
||||
elif file_ext == ".yaml":
|
||||
with g_pathmgr.open(filename, "w") as fopen:
|
||||
dump = yaml.dump(data)
|
||||
fopen.write(dump)
|
||||
fopen.flush()
|
||||
else:
|
||||
raise Exception(f"Saving {file_ext} is not supported yet")
|
||||
|
||||
if verbose:
|
||||
logging.info(f"Saved data to file: {filename}")
|
||||
|
||||
|
||||
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
|
||||
"""
|
||||
Common i/o utility to handle loading data from various file formats.
|
||||
Supported:
|
||||
.pkl, .pickle, .npy, .json
|
||||
For the npy files, we support reading the files in mmap_mode.
|
||||
If the mmap_mode of reading is not successful, we load data without the
|
||||
mmap_mode.
|
||||
"""
|
||||
if verbose:
|
||||
logging.info(f"Loading data from file: {filename}")
|
||||
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
if file_ext == ".txt":
|
||||
with g_pathmgr.open(filename, "r") as fopen:
|
||||
data = fopen.readlines()
|
||||
elif file_ext in [".pkl", ".pickle"]:
|
||||
with g_pathmgr.open(filename, "rb") as fopen:
|
||||
data = pickle.load(fopen, encoding="latin1")
|
||||
elif file_ext == ".npy":
|
||||
if mmap_mode:
|
||||
try:
|
||||
with g_pathmgr.open(filename, "rb") as fopen:
|
||||
data = np.load(
|
||||
fopen,
|
||||
allow_pickle=allow_pickle,
|
||||
encoding="latin1",
|
||||
mmap_mode=mmap_mode,
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.info(
|
||||
f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
|
||||
)
|
||||
data = np.load(
|
||||
filename,
|
||||
allow_pickle=allow_pickle,
|
||||
encoding="latin1",
|
||||
mmap_mode=mmap_mode,
|
||||
)
|
||||
logging.info("Successfully loaded without g_pathmgr")
|
||||
except Exception:
|
||||
logging.info("Could not mmap without g_pathmgr. Trying without mmap")
|
||||
with g_pathmgr.open(filename, "rb") as fopen:
|
||||
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
||||
else:
|
||||
with g_pathmgr.open(filename, "rb") as fopen:
|
||||
data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
|
||||
elif file_ext == ".json":
|
||||
with g_pathmgr.open(filename, "r") as fopen:
|
||||
data = json.load(fopen)
|
||||
elif file_ext == ".yaml":
|
||||
with g_pathmgr.open(filename, "r") as fopen:
|
||||
data = yaml.load(fopen, Loader=yaml.FullLoader)
|
||||
elif file_ext == ".csv":
|
||||
with g_pathmgr.open(filename, "r") as fopen:
|
||||
data = pd.read_csv(fopen)
|
||||
else:
|
||||
raise Exception(f"Reading from {file_ext} is not supported yet")
|
||||
return data
|
||||
|
||||
|
||||
def abspath(resource_path: str):
|
||||
"""
|
||||
Make a path absolute, but take into account prefixes like
|
||||
"http://" or "manifold://"
|
||||
"""
|
||||
regex = re.compile(r"^\w+://")
|
||||
if regex.match(resource_path) is None:
|
||||
return os.path.abspath(resource_path)
|
||||
else:
|
||||
return resource_path
|
||||
|
||||
|
||||
def makedir(dir_path):
|
||||
"""
|
||||
Create the directory if it does not exist.
|
||||
"""
|
||||
is_success = False
|
||||
try:
|
||||
if not g_pathmgr.exists(dir_path):
|
||||
g_pathmgr.mkdirs(dir_path)
|
||||
is_success = True
|
||||
except BaseException:
|
||||
logging.info(f"Error creating directory: {dir_path}")
|
||||
return is_success
|
||||
|
||||
|
||||
def is_url(input_url):
|
||||
"""
|
||||
Check if an input string is a url. look for http(s):// and ignoring the case
|
||||
"""
|
||||
is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
|
||||
return is_url
|
||||
|
||||
|
||||
def cleanup_dir(dir):
|
||||
"""
|
||||
Utility for deleting a directory. Useful for cleaning the storage space
|
||||
that contains various training artifacts like checkpoints, data etc.
|
||||
"""
|
||||
if os.path.exists(dir):
|
||||
logging.info(f"Deleting directory: {dir}")
|
||||
shutil.rmtree(dir)
|
||||
logging.info(f"Deleted contents of directory: {dir}")
|
||||
|
||||
|
||||
def get_file_size(filename):
|
||||
"""
|
||||
Given a file, get the size of file in MB
|
||||
"""
|
||||
size_in_mb = os.path.getsize(filename) / float(1024**2)
|
||||
return size_in_mb
|
5
minigpt4/configs/datasets/cc_sbu/align.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
data_type: images
|
||||
build_info:
|
||||
storage: /path/to/cc_sbu_align/
|
5
minigpt4/configs/datasets/cc_sbu/defaults.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
cc_sbu:
|
||||
data_type: images
|
||||
build_info:
|
||||
storage: /path/to/cc_sbu_dataset/{00000..01255}.tar
|
5
minigpt4/configs/datasets/laion/defaults.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
laion:
|
||||
data_type: images
|
||||
build_info:
|
||||
storage: /path/to/laion_dataset/{00000..10488}.tar
|
5
minigpt4/configs/default.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
env:
|
||||
# For default users
|
||||
# cache_root: "cache"
|
||||
# For internal use with persistent storage
|
||||
cache_root: "/export/home/.cache/minigpt4"
|
33
minigpt4/configs/models/minigpt4.yaml
Normal file
@ -0,0 +1,33 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
|
||||
# vit encoder
|
||||
image_size: 224
|
||||
drop_path_rate: 0
|
||||
use_grad_checkpoint: False
|
||||
vit_precision: "fp16"
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
|
||||
# Q-Former
|
||||
num_query_token: 32
|
||||
|
||||
# Vicuna
|
||||
llama_model: "/path/to/vicuna/weights/"
|
||||
|
||||
# generation configs
|
||||
prompt: ""
|
||||
|
||||
preprocess:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
image_size: 224
|
||||
eval:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
eval:
|
||||
name: "blip_caption"
|
0
minigpt4/conversation/__init__.py
Normal file
195
minigpt4/conversation/conversation.py
Normal file
@ -0,0 +1,195 @@
|
||||
import argparse
|
||||
import time
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
# system_img: List[Image.Image] = []
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: str = None
|
||||
|
||||
skip_next: bool = False
|
||||
conv_id: Any = None
|
||||
|
||||
def get_prompt(self):
|
||||
if self.sep_style == SeparatorStyle.SINGLE:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ": " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||
if i % 2 == 0:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret[-1][-1] = msg
|
||||
return ret
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
# system_img=self.system_img,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
conv_id=self.conv_id)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"system": self.system,
|
||||
# "system_img": self.system_img,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
"conv_id": self.conv_id,
|
||||
}
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
|
||||
|
||||
def __init__(self, stops=[], encounters=1):
|
||||
super().__init__()
|
||||
self.stops = stops
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||
for stop in self.stops:
|
||||
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
CONV_VISION = Conversation(
|
||||
system="Give the following image: <Img>ImageContent</Img>. "
|
||||
"You will be able to see the image once I provide it to you. Please answer my questions.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=[],
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(self, model, vis_processor, device='cuda:0'):
|
||||
self.device = device
|
||||
self.model = model
|
||||
self.vis_processor = vis_processor
|
||||
stop_words_ids = [torch.tensor([835]).to(self.device),
|
||||
torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
|
||||
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
||||
|
||||
def ask(self, text, conv):
|
||||
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
|
||||
and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
|
||||
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
|
||||
else:
|
||||
conv.append_message(conv.roles[0], text)
|
||||
|
||||
def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
|
||||
repetition_penalty=1.0, length_penalty=1, temperature=1):
|
||||
conv.append_message(conv.roles[1], None)
|
||||
embs = self.get_context_emb(conv, img_list)
|
||||
outputs = self.model.llama_model.generate(
|
||||
inputs_embeds=embs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
stopping_criteria=self.stopping_criteria,
|
||||
num_beams=num_beams,
|
||||
min_length=min_length,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty,
|
||||
temperature=temperature,
|
||||
)
|
||||
output_token = outputs[0]
|
||||
if output_token[0] == 0:
|
||||
output_token = output_token[1:]
|
||||
output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
|
||||
output_text = output_text.split('###')[0] # remove the stop sign '###'
|
||||
output_text = output_text.split('Assistant:')[-1].strip()
|
||||
conv.messages[-1][1] = output_text
|
||||
return output_text, output_token.cpu().numpy()
|
||||
|
||||
def upload_img(self, image, conv, img_list):
|
||||
if isinstance(image, str): # is a image path
|
||||
raw_image = Image.open(image).convert('RGB')
|
||||
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
||||
elif isinstance(image, Image.Image):
|
||||
raw_image = image
|
||||
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
|
||||
elif isinstance(image, torch.Tensor):
|
||||
if len(image.shape) == 3:
|
||||
image = image.unsqueeze(0)
|
||||
image = image.to(self.device)
|
||||
|
||||
image_emb, _ = self.model.encode_img(image)
|
||||
img_list.append(image_emb)
|
||||
conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
|
||||
msg = "Received."
|
||||
# self.conv.append_message(self.conv.roles[1], msg)
|
||||
return msg
|
||||
|
||||
def get_context_emb(self, conv, img_list):
|
||||
prompt = conv.get_prompt()
|
||||
prompt_segs = prompt.split('<ImageHere>')
|
||||
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
|
||||
seg_tokens = [
|
||||
self.model.llama_tokenizer(
|
||||
seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
|
||||
# only add bos to the first seg
|
||||
for i, seg in enumerate(prompt_segs)
|
||||
]
|
||||
seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
|
||||
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
|
||||
mixed_embs = torch.cat(mixed_embs, dim=1)
|
||||
return mixed_embs
|
||||
|
||||
|
0
minigpt4/datasets/__init__.py
Normal file
72
minigpt4/datasets/builders/__init__.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
from minigpt4.datasets.builders.base_dataset_builder import load_dataset_config
|
||||
from minigpt4.datasets.builders.image_text_pair_builder import (
|
||||
CCSBUBuilder,
|
||||
LaionBuilder,
|
||||
CCSBUAlignBuilder
|
||||
)
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
__all__ = [
|
||||
"CCSBUBuilder",
|
||||
"LaionBuilder",
|
||||
"CCSBUAlignBuilder"
|
||||
]
|
||||
|
||||
|
||||
def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
|
||||
"""
|
||||
Example
|
||||
|
||||
>>> dataset = load_dataset("coco_caption", cfg=None)
|
||||
>>> splits = dataset.keys()
|
||||
>>> print([len(dataset[split]) for split in splits])
|
||||
|
||||
"""
|
||||
if cfg_path is None:
|
||||
cfg = None
|
||||
else:
|
||||
cfg = load_dataset_config(cfg_path)
|
||||
|
||||
try:
|
||||
builder = registry.get_builder_class(name)(cfg)
|
||||
except TypeError:
|
||||
print(
|
||||
f"Dataset {name} not found. Available datasets:\n"
|
||||
+ ", ".join([str(k) for k in dataset_zoo.get_names()])
|
||||
)
|
||||
exit(1)
|
||||
|
||||
if vis_path is not None:
|
||||
if data_type is None:
|
||||
# use default data type in the config
|
||||
data_type = builder.config.data_type
|
||||
|
||||
assert (
|
||||
data_type in builder.config.build_info
|
||||
), f"Invalid data_type {data_type} for {name}."
|
||||
|
||||
builder.config.build_info.get(data_type).storage = vis_path
|
||||
|
||||
dataset = builder.build_datasets()
|
||||
return dataset
|
||||
|
||||
|
||||
class DatasetZoo:
|
||||
def __init__(self) -> None:
|
||||
self.dataset_zoo = {
|
||||
k: list(v.DATASET_CONFIG_DICT.keys())
|
||||
for k, v in sorted(registry.mapping["builder_name_mapping"].items())
|
||||
}
|
||||
|
||||
def get_names(self):
|
||||
return list(self.dataset_zoo.keys())
|
||||
|
||||
|
||||
dataset_zoo = DatasetZoo()
|
236
minigpt4/datasets/builders/base_dataset_builder.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""
|
||||
This file is from
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
import torch.distributed as dist
|
||||
from torchvision.datasets.utils import download_url
|
||||
|
||||
import minigpt4.common.utils as utils
|
||||
from minigpt4.common.dist_utils import is_dist_avail_and_initialized, is_main_process
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.processors.base_processor import BaseProcessor
|
||||
|
||||
|
||||
|
||||
class BaseDatasetBuilder:
|
||||
train_dataset_cls, eval_dataset_cls = None, None
|
||||
|
||||
def __init__(self, cfg=None):
|
||||
super().__init__()
|
||||
|
||||
if cfg is None:
|
||||
# help to create datasets from default config.
|
||||
self.config = load_dataset_config(self.default_config_path())
|
||||
elif isinstance(cfg, str):
|
||||
self.config = load_dataset_config(cfg)
|
||||
else:
|
||||
# when called from task.build_dataset()
|
||||
self.config = cfg
|
||||
|
||||
self.data_type = self.config.data_type
|
||||
|
||||
self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
|
||||
self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
|
||||
|
||||
def build_datasets(self):
|
||||
# download, split, etc...
|
||||
# only called on 1 GPU/TPU in distributed
|
||||
|
||||
if is_main_process():
|
||||
self._download_data()
|
||||
|
||||
if is_dist_avail_and_initialized():
|
||||
dist.barrier()
|
||||
|
||||
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
||||
logging.info("Building datasets...")
|
||||
datasets = self.build() # dataset['train'/'val'/'test']
|
||||
|
||||
return datasets
|
||||
|
||||
def build_processors(self):
|
||||
vis_proc_cfg = self.config.get("vis_processor")
|
||||
txt_proc_cfg = self.config.get("text_processor")
|
||||
|
||||
if vis_proc_cfg is not None:
|
||||
vis_train_cfg = vis_proc_cfg.get("train")
|
||||
vis_eval_cfg = vis_proc_cfg.get("eval")
|
||||
|
||||
self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
|
||||
self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
|
||||
|
||||
if txt_proc_cfg is not None:
|
||||
txt_train_cfg = txt_proc_cfg.get("train")
|
||||
txt_eval_cfg = txt_proc_cfg.get("eval")
|
||||
|
||||
self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
|
||||
self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
|
||||
|
||||
@staticmethod
|
||||
def _build_proc_from_cfg(cfg):
|
||||
return (
|
||||
registry.get_processor_class(cfg.name).from_config(cfg)
|
||||
if cfg is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def default_config_path(cls, type="default"):
|
||||
return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
|
||||
|
||||
def _download_data(self):
|
||||
self._download_ann()
|
||||
self._download_vis()
|
||||
|
||||
def _download_ann(self):
|
||||
"""
|
||||
Download annotation files if necessary.
|
||||
All the vision-language datasets should have annotations of unified format.
|
||||
|
||||
storage_path can be:
|
||||
(1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
|
||||
(2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
|
||||
|
||||
Local annotation paths should be relative.
|
||||
"""
|
||||
anns = self.config.build_info.annotations
|
||||
|
||||
splits = anns.keys()
|
||||
|
||||
cache_root = registry.get_path("cache_root")
|
||||
|
||||
for split in splits:
|
||||
info = anns[split]
|
||||
|
||||
urls, storage_paths = info.get("url", None), info.storage
|
||||
|
||||
if isinstance(urls, str):
|
||||
urls = [urls]
|
||||
if isinstance(storage_paths, str):
|
||||
storage_paths = [storage_paths]
|
||||
|
||||
assert len(urls) == len(storage_paths)
|
||||
|
||||
for url_or_filename, storage_path in zip(urls, storage_paths):
|
||||
# if storage_path is relative, make it full by prefixing with cache_root.
|
||||
if not os.path.isabs(storage_path):
|
||||
storage_path = os.path.join(cache_root, storage_path)
|
||||
|
||||
dirname = os.path.dirname(storage_path)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
if os.path.isfile(url_or_filename):
|
||||
src, dst = url_or_filename, storage_path
|
||||
if not os.path.exists(dst):
|
||||
shutil.copyfile(src=src, dst=dst)
|
||||
else:
|
||||
logging.info("Using existing file {}.".format(dst))
|
||||
else:
|
||||
if os.path.isdir(storage_path):
|
||||
# if only dirname is provided, suffix with basename of URL.
|
||||
raise ValueError(
|
||||
"Expecting storage_path to be a file path, got directory {}".format(
|
||||
storage_path
|
||||
)
|
||||
)
|
||||
else:
|
||||
filename = os.path.basename(storage_path)
|
||||
|
||||
download_url(url=url_or_filename, root=dirname, filename=filename)
|
||||
|
||||
def _download_vis(self):
|
||||
|
||||
storage_path = self.config.build_info.get(self.data_type).storage
|
||||
storage_path = utils.get_cache_path(storage_path)
|
||||
|
||||
if not os.path.exists(storage_path):
|
||||
warnings.warn(
|
||||
f"""
|
||||
The specified path {storage_path} for visual inputs does not exist.
|
||||
Please provide a correct path to the visual inputs or
|
||||
refer to datasets/download_scripts/README.md for downloading instructions.
|
||||
"""
|
||||
)
|
||||
|
||||
def build(self):
|
||||
"""
|
||||
Create by split datasets inheriting torch.utils.data.Datasets.
|
||||
|
||||
# build() can be dataset-specific. Overwrite to customize.
|
||||
"""
|
||||
self.build_processors()
|
||||
|
||||
build_info = self.config.build_info
|
||||
|
||||
ann_info = build_info.annotations
|
||||
vis_info = build_info.get(self.data_type)
|
||||
|
||||
datasets = dict()
|
||||
for split in ann_info.keys():
|
||||
if split not in ["train", "val", "test"]:
|
||||
continue
|
||||
|
||||
is_train = split == "train"
|
||||
|
||||
# processors
|
||||
vis_processor = (
|
||||
self.vis_processors["train"]
|
||||
if is_train
|
||||
else self.vis_processors["eval"]
|
||||
)
|
||||
text_processor = (
|
||||
self.text_processors["train"]
|
||||
if is_train
|
||||
else self.text_processors["eval"]
|
||||
)
|
||||
|
||||
# annotation path
|
||||
ann_paths = ann_info.get(split).storage
|
||||
if isinstance(ann_paths, str):
|
||||
ann_paths = [ann_paths]
|
||||
|
||||
abs_ann_paths = []
|
||||
for ann_path in ann_paths:
|
||||
if not os.path.isabs(ann_path):
|
||||
ann_path = utils.get_cache_path(ann_path)
|
||||
abs_ann_paths.append(ann_path)
|
||||
ann_paths = abs_ann_paths
|
||||
|
||||
# visual data storage path
|
||||
vis_path = os.path.join(vis_info.storage, split)
|
||||
|
||||
if not os.path.isabs(vis_path):
|
||||
# vis_path = os.path.join(utils.get_cache_path(), vis_path)
|
||||
vis_path = utils.get_cache_path(vis_path)
|
||||
|
||||
if not os.path.exists(vis_path):
|
||||
warnings.warn("storage path {} does not exist.".format(vis_path))
|
||||
|
||||
# create datasets
|
||||
dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
|
||||
datasets[split] = dataset_cls(
|
||||
vis_processor=vis_processor,
|
||||
text_processor=text_processor,
|
||||
ann_paths=ann_paths,
|
||||
vis_root=vis_path,
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
def load_dataset_config(cfg_path):
|
||||
cfg = OmegaConf.load(cfg_path).datasets
|
||||
cfg = cfg[list(cfg.keys())[0]]
|
||||
|
||||
return cfg
|
104
minigpt4/datasets/builders/image_text_pair_builder.py
Normal file
@ -0,0 +1,104 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder
|
||||
from minigpt4.datasets.datasets.laion_dataset import LaionDataset
|
||||
from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
|
||||
|
||||
|
||||
@registry.register_builder("cc_sbu")
|
||||
class CCSBUBuilder(BaseDatasetBuilder):
|
||||
train_dataset_cls = CCSBUDataset
|
||||
|
||||
DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
|
||||
|
||||
def _download_ann(self):
|
||||
pass
|
||||
|
||||
def _download_vis(self):
|
||||
pass
|
||||
|
||||
def build(self):
|
||||
self.build_processors()
|
||||
|
||||
build_info = self.config.build_info
|
||||
|
||||
datasets = dict()
|
||||
split = "train"
|
||||
|
||||
# create datasets
|
||||
# [NOTE] return inner_datasets (wds.DataPipeline)
|
||||
dataset_cls = self.train_dataset_cls
|
||||
datasets[split] = dataset_cls(
|
||||
vis_processor=self.vis_processors[split],
|
||||
text_processor=self.text_processors[split],
|
||||
location=build_info.storage,
|
||||
).inner_dataset
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
@registry.register_builder("laion")
|
||||
class LaionBuilder(BaseDatasetBuilder):
|
||||
train_dataset_cls = LaionDataset
|
||||
|
||||
DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
|
||||
|
||||
def _download_ann(self):
|
||||
pass
|
||||
|
||||
def _download_vis(self):
|
||||
pass
|
||||
|
||||
def build(self):
|
||||
self.build_processors()
|
||||
|
||||
build_info = self.config.build_info
|
||||
|
||||
datasets = dict()
|
||||
split = "train"
|
||||
|
||||
# create datasets
|
||||
# [NOTE] return inner_datasets (wds.DataPipeline)
|
||||
dataset_cls = self.train_dataset_cls
|
||||
datasets[split] = dataset_cls(
|
||||
vis_processor=self.vis_processors[split],
|
||||
text_processor=self.text_processors[split],
|
||||
location=build_info.storage,
|
||||
).inner_dataset
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
@registry.register_builder("cc_sbu_align")
|
||||
class CCSBUAlignBuilder(BaseDatasetBuilder):
|
||||
train_dataset_cls = CCSBUAlignDataset
|
||||
|
||||
DATASET_CONFIG_DICT = {
|
||||
"default": "configs/datasets/cc_sbu/align.yaml",
|
||||
}
|
||||
|
||||
def build_datasets(self):
|
||||
# at this point, all the annotations and image/videos should be all downloaded to the specified locations.
|
||||
logging.info("Building datasets...")
|
||||
self.build_processors()
|
||||
|
||||
build_info = self.config.build_info
|
||||
storage_path = build_info.storage
|
||||
|
||||
datasets = dict()
|
||||
|
||||
if not os.path.exists(storage_path):
|
||||
warnings.warn("storage path {} does not exist.".format(vis_path))
|
||||
|
||||
# create datasets
|
||||
dataset_cls = self.train_dataset_cls
|
||||
datasets['train'] = dataset_cls(
|
||||
vis_processor=self.vis_processors["train"],
|
||||
text_processor=self.text_processors["train"],
|
||||
ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
|
||||
vis_root=os.path.join(storage_path, 'image'),
|
||||
)
|
||||
|
||||
return datasets
|
196
minigpt4/datasets/data_utils.py
Normal file
@ -0,0 +1,196 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import gzip
|
||||
import logging
|
||||
import os
|
||||
import random as rnd
|
||||
import tarfile
|
||||
import zipfile
|
||||
import random
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
|
||||
import decord
|
||||
from decord import VideoReader
|
||||
import webdataset as wds
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataset import IterableDataset
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.datasets.datasets.base_dataset import ConcatDataset
|
||||
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
MAX_INT = registry.get("MAX_INT")
|
||||
|
||||
|
||||
class ChainDataset(wds.DataPipeline):
|
||||
r"""Dataset for chaining multiple :class:`DataPipeline` s.
|
||||
|
||||
This class is useful to assemble different existing dataset streams. The
|
||||
chaining operation is done on-the-fly, so concatenating large-scale
|
||||
datasets with this class will be efficient.
|
||||
|
||||
Args:
|
||||
datasets (iterable of IterableDataset): datasets to be chained together
|
||||
"""
|
||||
def __init__(self, datasets: List[wds.DataPipeline]) -> None:
|
||||
super().__init__()
|
||||
self.datasets = datasets
|
||||
self.prob = []
|
||||
self.names = []
|
||||
for dataset in self.datasets:
|
||||
if hasattr(dataset, 'name'):
|
||||
self.names.append(dataset.name)
|
||||
else:
|
||||
self.names.append('Unknown')
|
||||
if hasattr(dataset, 'sample_ratio'):
|
||||
self.prob.append(dataset.sample_ratio)
|
||||
else:
|
||||
self.prob.append(1)
|
||||
logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
|
||||
|
||||
def __iter__(self):
|
||||
datastreams = [iter(dataset) for dataset in self.datasets]
|
||||
while True:
|
||||
select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
|
||||
yield next(select_datastream)
|
||||
|
||||
|
||||
def apply_to_sample(f, sample):
|
||||
if len(sample) == 0:
|
||||
return {}
|
||||
|
||||
def _apply(x):
|
||||
if torch.is_tensor(x):
|
||||
return f(x)
|
||||
elif isinstance(x, dict):
|
||||
return {key: _apply(value) for key, value in x.items()}
|
||||
elif isinstance(x, list):
|
||||
return [_apply(x) for x in x]
|
||||
else:
|
||||
return x
|
||||
|
||||
return _apply(sample)
|
||||
|
||||
|
||||
def move_to_cuda(sample):
|
||||
def _move_to_cuda(tensor):
|
||||
return tensor.cuda()
|
||||
|
||||
return apply_to_sample(_move_to_cuda, sample)
|
||||
|
||||
|
||||
def prepare_sample(samples, cuda_enabled=True):
|
||||
if cuda_enabled:
|
||||
samples = move_to_cuda(samples)
|
||||
|
||||
# TODO fp16 support
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def reorg_datasets_by_split(datasets):
|
||||
"""
|
||||
Organizes datasets by split.
|
||||
|
||||
Args:
|
||||
datasets: dict of torch.utils.data.Dataset objects by name.
|
||||
|
||||
Returns:
|
||||
Dict of datasets by split {split_name: List[Datasets]}.
|
||||
"""
|
||||
# if len(datasets) == 1:
|
||||
# return datasets[list(datasets.keys())[0]]
|
||||
# else:
|
||||
reorg_datasets = dict()
|
||||
|
||||
# reorganize by split
|
||||
for _, dataset in datasets.items():
|
||||
for split_name, dataset_split in dataset.items():
|
||||
if split_name not in reorg_datasets:
|
||||
reorg_datasets[split_name] = [dataset_split]
|
||||
else:
|
||||
reorg_datasets[split_name].append(dataset_split)
|
||||
|
||||
return reorg_datasets
|
||||
|
||||
|
||||
def concat_datasets(datasets):
|
||||
"""
|
||||
Concatenates multiple datasets into a single dataset.
|
||||
|
||||
It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
|
||||
generic IterableDataset because it requires creating separate samplers.
|
||||
|
||||
Now only supports conctenating training datasets and assuming validation and testing
|
||||
have only a single dataset. This is because metrics should not be computed on the concatenated
|
||||
datasets.
|
||||
|
||||
Args:
|
||||
datasets: dict of torch.utils.data.Dataset objects by split.
|
||||
|
||||
Returns:
|
||||
Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
|
||||
"val" and "test" remain the same.
|
||||
|
||||
If the input training datasets contain both map-style and DataPipeline datasets, returns
|
||||
a tuple, where the first element is a concatenated map-style dataset and the second
|
||||
element is a chained DataPipeline dataset.
|
||||
|
||||
"""
|
||||
# concatenate datasets in the same split
|
||||
for split_name in datasets:
|
||||
if split_name != "train":
|
||||
assert (
|
||||
len(datasets[split_name]) == 1
|
||||
), "Do not support multiple {} datasets.".format(split_name)
|
||||
datasets[split_name] = datasets[split_name][0]
|
||||
else:
|
||||
iterable_datasets, map_datasets = [], []
|
||||
for dataset in datasets[split_name]:
|
||||
if isinstance(dataset, wds.DataPipeline):
|
||||
logging.info(
|
||||
"Dataset {} is IterableDataset, can't be concatenated.".format(
|
||||
dataset
|
||||
)
|
||||
)
|
||||
iterable_datasets.append(dataset)
|
||||
elif isinstance(dataset, IterableDataset):
|
||||
raise NotImplementedError(
|
||||
"Do not support concatenation of generic IterableDataset."
|
||||
)
|
||||
else:
|
||||
map_datasets.append(dataset)
|
||||
|
||||
# if len(iterable_datasets) > 0:
|
||||
# concatenate map-style datasets and iterable-style datasets separately
|
||||
if len(iterable_datasets) > 1:
|
||||
chained_datasets = (
|
||||
ChainDataset(iterable_datasets)
|
||||
)
|
||||
elif len(iterable_datasets) == 1:
|
||||
chained_datasets = iterable_datasets[0]
|
||||
else:
|
||||
chained_datasets = None
|
||||
|
||||
concat_datasets = (
|
||||
ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
|
||||
)
|
||||
|
||||
train_datasets = concat_datasets, chained_datasets
|
||||
train_datasets = tuple([x for x in train_datasets if x is not None])
|
||||
train_datasets = (
|
||||
train_datasets[0] if len(train_datasets) == 1 else train_datasets
|
||||
)
|
||||
|
||||
datasets[split_name] = train_datasets
|
||||
|
||||
return datasets
|
||||
|
0
minigpt4/datasets/datasets/__init__.py
Normal file
68
minigpt4/datasets/datasets/base_dataset.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Iterable
|
||||
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
def __init__(
|
||||
self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
|
||||
self.annotation = []
|
||||
for ann_path in ann_paths:
|
||||
self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
|
||||
self._add_instance_ids()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.annotation)
|
||||
|
||||
def collater(self, samples):
|
||||
return default_collate(samples)
|
||||
|
||||
def set_processors(self, vis_processor, text_processor):
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
|
||||
def _add_instance_ids(self, key="instance_id"):
|
||||
for idx, ann in enumerate(self.annotation):
|
||||
ann[key] = str(idx)
|
||||
|
||||
|
||||
class ConcatDataset(ConcatDataset):
|
||||
def __init__(self, datasets: Iterable[Dataset]) -> None:
|
||||
super().__init__(datasets)
|
||||
|
||||
def collater(self, samples):
|
||||
# TODO For now only supports datasets with same underlying collater implementations
|
||||
|
||||
all_keys = set()
|
||||
for s in samples:
|
||||
all_keys.update(s)
|
||||
|
||||
shared_keys = all_keys
|
||||
for s in samples:
|
||||
shared_keys = shared_keys & set(s.keys())
|
||||
|
||||
samples_shared_keys = []
|
||||
for s in samples:
|
||||
samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
|
||||
|
||||
return self.datasets[0].collater(samples_shared_keys)
|
85
minigpt4/datasets/datasets/caption_datasets.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
from minigpt4.datasets.datasets.base_dataset import BaseDataset
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class __DisplMixin:
|
||||
def displ_item(self, index):
|
||||
sample, ann = self.__getitem__(index), self.annotation[index]
|
||||
|
||||
return OrderedDict(
|
||||
{
|
||||
"file": ann["image"],
|
||||
"caption": ann["caption"],
|
||||
"image": sample["image"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class CaptionDataset(BaseDataset, __DisplMixin):
|
||||
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
|
||||
|
||||
self.img_ids = {}
|
||||
n = 0
|
||||
for ann in self.annotation:
|
||||
img_id = ann["image_id"]
|
||||
if img_id not in self.img_ids.keys():
|
||||
self.img_ids[img_id] = n
|
||||
n += 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
# TODO this assumes image input, not general enough
|
||||
ann = self.annotation[index]
|
||||
|
||||
img_file = '{:0>12}.jpg'.format(ann["image_id"])
|
||||
image_path = os.path.join(self.vis_root, img_file)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
image = self.vis_processor(image)
|
||||
caption = self.text_processor(ann["caption"])
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"text_input": caption,
|
||||
"image_id": self.img_ids[ann["image_id"]],
|
||||
}
|
||||
|
||||
|
||||
class CaptionEvalDataset(BaseDataset, __DisplMixin):
|
||||
def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
split (string): val or test
|
||||
"""
|
||||
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
ann = self.annotation[index]
|
||||
|
||||
image_path = os.path.join(self.vis_root, ann["image"])
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
image = self.vis_processor(image)
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"image_id": ann["image_id"],
|
||||
"instance_id": ann["instance_id"],
|
||||
}
|
47
minigpt4/datasets/datasets/cc_sbu_dataset.py
Normal file
@ -0,0 +1,47 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
import webdataset as wds
|
||||
from minigpt4.datasets.datasets.base_dataset import BaseDataset
|
||||
from minigpt4.datasets.datasets.caption_datasets import CaptionDataset
|
||||
|
||||
|
||||
class CCSBUDataset(BaseDataset):
|
||||
def __init__(self, vis_processor, text_processor, location):
|
||||
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
||||
|
||||
self.inner_dataset = wds.DataPipeline(
|
||||
wds.ResampledShards(location),
|
||||
wds.tarfile_to_samples(handler=wds.warn_and_continue),
|
||||
wds.shuffle(1000, handler=wds.warn_and_continue),
|
||||
wds.decode("pilrgb", handler=wds.warn_and_continue),
|
||||
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
|
||||
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
|
||||
wds.map(self.to_dict, handler=wds.warn_and_continue),
|
||||
)
|
||||
|
||||
def to_dict(self, sample):
|
||||
return {
|
||||
"image": sample[0],
|
||||
"text_input": self.text_processor(sample[1]["caption"]),
|
||||
}
|
||||
|
||||
|
||||
class CCSBUAlignDataset(CaptionDataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
# TODO this assumes image input, not general enough
|
||||
ann = self.annotation[index]
|
||||
|
||||
img_file = '{}.jpg'.format(ann["image_id"])
|
||||
image_path = os.path.join(self.vis_root, img_file)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
image = self.vis_processor(image)
|
||||
caption = ann["caption"]
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"text_input": caption,
|
||||
"image_id": self.img_ids[ann["image_id"]],
|
||||
}
|
162
minigpt4/datasets/datasets/dataloader_utils.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import time
|
||||
import random
|
||||
import torch
|
||||
from minigpt4.datasets.data_utils import move_to_cuda
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class MultiIterLoader:
|
||||
"""
|
||||
A simple wrapper for iterating over multiple iterators.
|
||||
|
||||
Args:
|
||||
loaders (List[Loader]): List of Iterator loaders.
|
||||
ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
|
||||
"""
|
||||
|
||||
def __init__(self, loaders, ratios=None):
|
||||
# assert all loaders has __next__ method
|
||||
for loader in loaders:
|
||||
assert hasattr(
|
||||
loader, "__next__"
|
||||
), "Loader {} has no __next__ method.".format(loader)
|
||||
|
||||
if ratios is None:
|
||||
ratios = [1.0] * len(loaders)
|
||||
else:
|
||||
assert len(ratios) == len(loaders)
|
||||
ratios = [float(ratio) / sum(ratios) for ratio in ratios]
|
||||
|
||||
self.loaders = loaders
|
||||
self.ratios = ratios
|
||||
|
||||
def __next__(self):
|
||||
# random sample from each loader by ratio
|
||||
loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
|
||||
return next(self.loaders[loader_idx])
|
||||
|
||||
|
||||
class PrefetchLoader(object):
|
||||
"""
|
||||
Modified from https://github.com/ChenRocks/UNITER.
|
||||
|
||||
overlap compute and cuda data transfer
|
||||
(copied and then modified from nvidia apex)
|
||||
"""
|
||||
|
||||
def __init__(self, loader):
|
||||
self.loader = loader
|
||||
self.stream = torch.cuda.Stream()
|
||||
|
||||
def __iter__(self):
|
||||
loader_it = iter(self.loader)
|
||||
self.preload(loader_it)
|
||||
batch = self.next(loader_it)
|
||||
while batch is not None:
|
||||
is_tuple = isinstance(batch, tuple)
|
||||
if is_tuple:
|
||||
task, batch = batch
|
||||
|
||||
if is_tuple:
|
||||
yield task, batch
|
||||
else:
|
||||
yield batch
|
||||
batch = self.next(loader_it)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.loader)
|
||||
|
||||
def preload(self, it):
|
||||
try:
|
||||
self.batch = next(it)
|
||||
except StopIteration:
|
||||
self.batch = None
|
||||
return
|
||||
# if record_stream() doesn't work, another option is to make sure
|
||||
# device inputs are created on the main stream.
|
||||
# self.next_input_gpu = torch.empty_like(self.next_input,
|
||||
# device='cuda')
|
||||
# self.next_target_gpu = torch.empty_like(self.next_target,
|
||||
# device='cuda')
|
||||
# Need to make sure the memory allocated for next_* is not still in use
|
||||
# by the main stream at the time we start copying to next_*:
|
||||
# self.stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.batch = move_to_cuda(self.batch)
|
||||
# more code for the alternative if record_stream() doesn't work:
|
||||
# copy_ will record the use of the pinned source tensor in this
|
||||
# side stream.
|
||||
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
|
||||
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
|
||||
# self.next_input = self.next_input_gpu
|
||||
# self.next_target = self.next_target_gpu
|
||||
|
||||
def next(self, it):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
batch = self.batch
|
||||
if batch is not None:
|
||||
record_cuda_stream(batch)
|
||||
self.preload(it)
|
||||
return batch
|
||||
|
||||
def __getattr__(self, name):
|
||||
method = self.loader.__getattribute__(name)
|
||||
return method
|
||||
|
||||
|
||||
def record_cuda_stream(batch):
|
||||
if isinstance(batch, torch.Tensor):
|
||||
batch.record_stream(torch.cuda.current_stream())
|
||||
elif isinstance(batch, list) or isinstance(batch, tuple):
|
||||
for t in batch:
|
||||
record_cuda_stream(t)
|
||||
elif isinstance(batch, dict):
|
||||
for t in batch.values():
|
||||
record_cuda_stream(t)
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
class IterLoader:
|
||||
"""
|
||||
A wrapper to convert DataLoader as an infinite iterator.
|
||||
|
||||
Modified from:
|
||||
https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
|
||||
self._dataloader = dataloader
|
||||
self.iter_loader = iter(self._dataloader)
|
||||
self._use_distributed = use_distributed
|
||||
self._epoch = 0
|
||||
|
||||
@property
|
||||
def epoch(self) -> int:
|
||||
return self._epoch
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
data = next(self.iter_loader)
|
||||
except StopIteration:
|
||||
self._epoch += 1
|
||||
if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
|
||||
self._dataloader.sampler.set_epoch(self._epoch)
|
||||
time.sleep(2) # Prevent possible deadlock during epoch transition
|
||||
self.iter_loader = iter(self._dataloader)
|
||||
data = next(self.iter_loader)
|
||||
|
||||
return data
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataloader)
|
31
minigpt4/datasets/datasets/laion_dataset.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import webdataset as wds
|
||||
from minigpt4.datasets.datasets.base_dataset import BaseDataset
|
||||
|
||||
|
||||
class LaionDataset(BaseDataset):
|
||||
def __init__(self, vis_processor, text_processor, location):
|
||||
super().__init__(vis_processor=vis_processor, text_processor=text_processor)
|
||||
|
||||
self.inner_dataset = wds.DataPipeline(
|
||||
wds.ResampledShards(location),
|
||||
wds.tarfile_to_samples(handler=wds.warn_and_continue),
|
||||
wds.shuffle(1000, handler=wds.warn_and_continue),
|
||||
wds.decode("pilrgb", handler=wds.warn_and_continue),
|
||||
wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
|
||||
wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
|
||||
wds.map(self.to_dict, handler=wds.warn_and_continue),
|
||||
)
|
||||
|
||||
def to_dict(self, sample):
|
||||
return {
|
||||
"image": sample[0],
|
||||
"text_input": self.text_processor(sample[1]["caption"]),
|
||||
}
|
||||
|
1216
minigpt4/models/Qformer.py
Normal file
200
minigpt4/models/__init__.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.models.base_model import BaseModel
|
||||
from minigpt4.models.blip2 import Blip2Base
|
||||
from minigpt4.models.mini_gpt4 import MiniGPT4
|
||||
from minigpt4.processors.base_processor import BaseProcessor
|
||||
|
||||
|
||||
__all__ = [
|
||||
"load_model",
|
||||
"BaseModel",
|
||||
"Blip2Base",
|
||||
"MiniGPT4",
|
||||
]
|
||||
|
||||
|
||||
def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
|
||||
"""
|
||||
Load supported models.
|
||||
|
||||
To list all available models and types in registry:
|
||||
>>> from minigpt4.models import model_zoo
|
||||
>>> print(model_zoo)
|
||||
|
||||
Args:
|
||||
name (str): name of the model.
|
||||
model_type (str): type of the model.
|
||||
is_eval (bool): whether the model is in eval mode. Default: False.
|
||||
device (str): device to use. Default: "cpu".
|
||||
checkpoint (str): path or to checkpoint. Default: None.
|
||||
Note that expecting the checkpoint to have the same keys in state_dict as the model.
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
"""
|
||||
|
||||
model = registry.get_model_class(name).from_pretrained(model_type=model_type)
|
||||
|
||||
if checkpoint is not None:
|
||||
model.load_checkpoint(checkpoint)
|
||||
|
||||
if is_eval:
|
||||
model.eval()
|
||||
|
||||
if device == "cpu":
|
||||
model = model.float()
|
||||
|
||||
return model.to(device)
|
||||
|
||||
|
||||
def load_preprocess(config):
|
||||
"""
|
||||
Load preprocessor configs and construct preprocessors.
|
||||
|
||||
If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
|
||||
|
||||
Args:
|
||||
config (dict): preprocessor configs.
|
||||
|
||||
Returns:
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
|
||||
Key is "train" or "eval" for processors used in training and evaluation respectively.
|
||||
"""
|
||||
|
||||
def _build_proc_from_cfg(cfg):
|
||||
return (
|
||||
registry.get_processor_class(cfg.name).from_config(cfg)
|
||||
if cfg is not None
|
||||
else BaseProcessor()
|
||||
)
|
||||
|
||||
vis_processors = dict()
|
||||
txt_processors = dict()
|
||||
|
||||
vis_proc_cfg = config.get("vis_processor")
|
||||
txt_proc_cfg = config.get("text_processor")
|
||||
|
||||
if vis_proc_cfg is not None:
|
||||
vis_train_cfg = vis_proc_cfg.get("train")
|
||||
vis_eval_cfg = vis_proc_cfg.get("eval")
|
||||
else:
|
||||
vis_train_cfg = None
|
||||
vis_eval_cfg = None
|
||||
|
||||
vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
|
||||
vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
|
||||
|
||||
if txt_proc_cfg is not None:
|
||||
txt_train_cfg = txt_proc_cfg.get("train")
|
||||
txt_eval_cfg = txt_proc_cfg.get("eval")
|
||||
else:
|
||||
txt_train_cfg = None
|
||||
txt_eval_cfg = None
|
||||
|
||||
txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
|
||||
txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
|
||||
|
||||
return vis_processors, txt_processors
|
||||
|
||||
|
||||
def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
|
||||
"""
|
||||
Load model and its related preprocessors.
|
||||
|
||||
List all available models and types in registry:
|
||||
>>> from minigpt4.models import model_zoo
|
||||
>>> print(model_zoo)
|
||||
|
||||
Args:
|
||||
name (str): name of the model.
|
||||
model_type (str): type of the model.
|
||||
is_eval (bool): whether the model is in eval mode. Default: False.
|
||||
device (str): device to use. Default: "cpu".
|
||||
|
||||
Returns:
|
||||
model (torch.nn.Module): model.
|
||||
vis_processors (dict): preprocessors for visual inputs.
|
||||
txt_processors (dict): preprocessors for text inputs.
|
||||
"""
|
||||
model_cls = registry.get_model_class(name)
|
||||
|
||||
# load model
|
||||
model = model_cls.from_pretrained(model_type=model_type)
|
||||
|
||||
if is_eval:
|
||||
model.eval()
|
||||
|
||||
# load preprocess
|
||||
cfg = OmegaConf.load(model_cls.default_config_path(model_type))
|
||||
if cfg is not None:
|
||||
preprocess_cfg = cfg.preprocess
|
||||
|
||||
vis_processors, txt_processors = load_preprocess(preprocess_cfg)
|
||||
else:
|
||||
vis_processors, txt_processors = None, None
|
||||
logging.info(
|
||||
f"""No default preprocess for model {name} ({model_type}).
|
||||
This can happen if the model is not finetuned on downstream datasets,
|
||||
or it is not intended for direct use without finetuning.
|
||||
"""
|
||||
)
|
||||
|
||||
if device == "cpu" or device == torch.device("cpu"):
|
||||
model = model.float()
|
||||
|
||||
return model.to(device), vis_processors, txt_processors
|
||||
|
||||
|
||||
class ModelZoo:
|
||||
"""
|
||||
A utility class to create string representation of available model architectures and types.
|
||||
|
||||
>>> from minigpt4.models import model_zoo
|
||||
>>> # list all available models
|
||||
>>> print(model_zoo)
|
||||
>>> # show total number of models
|
||||
>>> print(len(model_zoo))
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.model_zoo = {
|
||||
k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
|
||||
for k, v in registry.mapping["model_name_mapping"].items()
|
||||
}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"=" * 50
|
||||
+ "\n"
|
||||
+ f"{'Architectures':<30} {'Types'}\n"
|
||||
+ "=" * 50
|
||||
+ "\n"
|
||||
+ "\n".join(
|
||||
[
|
||||
f"{name:<30} {', '.join(types)}"
|
||||
for name, types in self.model_zoo.items()
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.model_zoo.items())
|
||||
|
||||
def __len__(self):
|
||||
return sum([len(v) for v in self.model_zoo.values()])
|
||||
|
||||
|
||||
model_zoo = ModelZoo()
|
247
minigpt4/models/base_model.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
|
||||
from minigpt4.common.utils import get_abs_path, is_url
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
"""Base class for models."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return list(self.parameters())[0].device
|
||||
|
||||
def load_checkpoint(self, url_or_filename):
|
||||
"""
|
||||
Load from a finetuned checkpoint.
|
||||
|
||||
This should expect no mismatch in the model keys and the checkpoint keys.
|
||||
"""
|
||||
|
||||
if is_url(url_or_filename):
|
||||
cached_file = download_cached_file(
|
||||
url_or_filename, check_hash=False, progress=True
|
||||
)
|
||||
checkpoint = torch.load(cached_file, map_location="cpu")
|
||||
elif os.path.isfile(url_or_filename):
|
||||
checkpoint = torch.load(url_or_filename, map_location="cpu")
|
||||
else:
|
||||
raise RuntimeError("checkpoint url or path is invalid")
|
||||
|
||||
if "model" in checkpoint.keys():
|
||||
state_dict = checkpoint["model"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
msg = self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
logging.info("Missing keys {}".format(msg.missing_keys))
|
||||
logging.info("load checkpoint from %s" % url_or_filename)
|
||||
|
||||
return msg
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_type):
|
||||
"""
|
||||
Build a pretrained model from default configuration file, specified by model_type.
|
||||
|
||||
Args:
|
||||
- model_type (str): model type, specifying architecture and checkpoints.
|
||||
|
||||
Returns:
|
||||
- model (nn.Module): pretrained or finetuned model, depending on the configuration.
|
||||
"""
|
||||
model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
|
||||
model = cls.from_config(model_cfg)
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def default_config_path(cls, model_type):
|
||||
assert (
|
||||
model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
|
||||
), "Unknown model type {}".format(model_type)
|
||||
return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
|
||||
|
||||
def load_checkpoint_from_config(self, cfg, **kwargs):
|
||||
"""
|
||||
Load checkpoint as specified in the config file.
|
||||
|
||||
If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
|
||||
When loading the pretrained model, each task-specific architecture may define their
|
||||
own load_from_pretrained() method.
|
||||
"""
|
||||
load_finetuned = cfg.get("load_finetuned", True)
|
||||
if load_finetuned:
|
||||
finetune_path = cfg.get("finetuned", None)
|
||||
assert (
|
||||
finetune_path is not None
|
||||
), "Found load_finetuned is True, but finetune_path is None."
|
||||
self.load_checkpoint(url_or_filename=finetune_path)
|
||||
else:
|
||||
# load pre-trained weights
|
||||
pretrain_path = cfg.get("pretrained", None)
|
||||
assert "Found load_finetuned is False, but pretrain_path is None."
|
||||
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
|
||||
|
||||
def before_evaluation(self, **kwargs):
|
||||
pass
|
||||
|
||||
def show_n_params(self, return_str=True):
|
||||
tot = 0
|
||||
for p in self.parameters():
|
||||
w = 1
|
||||
for x in p.shape:
|
||||
w *= x
|
||||
tot += w
|
||||
if return_str:
|
||||
if tot >= 1e6:
|
||||
return "{:.1f}M".format(tot / 1e6)
|
||||
else:
|
||||
return "{:.1f}K".format(tot / 1e3)
|
||||
else:
|
||||
return tot
|
||||
|
||||
|
||||
class BaseEncoder(nn.Module):
|
||||
"""
|
||||
Base class for primitive encoders, such as ViT, TimeSformer, etc.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward_features(self, samples, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return list(self.parameters())[0].device
|
||||
|
||||
|
||||
class SharedQueueMixin:
|
||||
@torch.no_grad()
|
||||
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
|
||||
# gather keys before updating queue
|
||||
image_feats = concat_all_gather(image_feat)
|
||||
text_feats = concat_all_gather(text_feat)
|
||||
|
||||
batch_size = image_feats.shape[0]
|
||||
|
||||
ptr = int(self.queue_ptr)
|
||||
assert self.queue_size % batch_size == 0 # for simplicity
|
||||
|
||||
# replace the keys at ptr (dequeue and enqueue)
|
||||
self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
|
||||
self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
|
||||
|
||||
if idxs is not None:
|
||||
idxs = concat_all_gather(idxs)
|
||||
self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
|
||||
|
||||
ptr = (ptr + batch_size) % self.queue_size # move pointer
|
||||
self.queue_ptr[0] = ptr
|
||||
|
||||
|
||||
class MomentumDistilationMixin:
|
||||
@torch.no_grad()
|
||||
def copy_params(self):
|
||||
for model_pair in self.model_pairs:
|
||||
for param, param_m in zip(
|
||||
model_pair[0].parameters(), model_pair[1].parameters()
|
||||
):
|
||||
param_m.data.copy_(param.data) # initialize
|
||||
param_m.requires_grad = False # not update by gradient
|
||||
|
||||
@torch.no_grad()
|
||||
def _momentum_update(self):
|
||||
for model_pair in self.model_pairs:
|
||||
for param, param_m in zip(
|
||||
model_pair[0].parameters(), model_pair[1].parameters()
|
||||
):
|
||||
param_m.data = param_m.data * self.momentum + param.data * (
|
||||
1.0 - self.momentum
|
||||
)
|
||||
|
||||
|
||||
class GatherLayer(torch.autograd.Function):
|
||||
"""
|
||||
Gather tensors from all workers with support for backward propagation:
|
||||
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
output = [
|
||||
torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(output, x)
|
||||
return tuple(output)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
all_gradients = torch.stack(grads)
|
||||
torch.distributed.all_reduce(all_gradients)
|
||||
return all_gradients[torch.distributed.get_rank()]
|
||||
|
||||
|
||||
def all_gather_with_grad(tensors):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
Graph remains connected for backward grad computation.
|
||||
"""
|
||||
# Queue the gathered tensors
|
||||
world_size = torch.distributed.get_world_size()
|
||||
# There is no need for reduction in the single-proc case
|
||||
if world_size == 1:
|
||||
return tensors
|
||||
|
||||
# tensor_all = GatherLayer.apply(tensors)
|
||||
tensor_all = GatherLayer.apply(tensors)
|
||||
|
||||
return torch.cat(tensor_all, dim=0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def concat_all_gather(tensor):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
*** Warning ***: torch.distributed.all_gather has no gradient.
|
||||
"""
|
||||
# if use distributed training
|
||||
if not is_dist_avail_and_initialized():
|
||||
return tensor
|
||||
|
||||
tensors_gather = [
|
||||
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
||||
|
||||
output = torch.cat(tensors_gather, dim=0)
|
||||
return output
|
||||
|
||||
|
||||
def tile(x, dim, n_tile):
|
||||
init_dim = x.size(dim)
|
||||
repeat_idx = [1] * x.dim()
|
||||
repeat_idx[dim] = n_tile
|
||||
x = x.repeat(*(repeat_idx))
|
||||
order_index = torch.LongTensor(
|
||||
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
|
||||
)
|
||||
return torch.index_select(x, dim, order_index.to(x.device))
|
221
minigpt4/models/blip2.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""
|
||||
Copyright (c) 2023, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
import minigpt4.common.dist_utils as dist_utils
|
||||
from minigpt4.common.dist_utils import download_cached_file
|
||||
from minigpt4.common.utils import is_url
|
||||
from minigpt4.common.logger import MetricLogger
|
||||
from minigpt4.models.base_model import BaseModel
|
||||
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
|
||||
from minigpt4.models.eva_vit import create_eva_vit_g
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
class Blip2Base(BaseModel):
|
||||
@classmethod
|
||||
def init_tokenizer(cls):
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
return tokenizer
|
||||
|
||||
def maybe_autocast(self, dtype=torch.float16):
|
||||
# if on cpu, don't use autocast
|
||||
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
|
||||
enable_autocast = self.device != torch.device("cpu")
|
||||
|
||||
if enable_autocast:
|
||||
return torch.cuda.amp.autocast(dtype=dtype)
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
@classmethod
|
||||
def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
|
||||
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
|
||||
encoder_config.encoder_width = vision_width
|
||||
# insert cross-attention layer every other block
|
||||
encoder_config.add_cross_attention = True
|
||||
encoder_config.cross_attention_freq = cross_attention_freq
|
||||
encoder_config.query_length = num_query_token
|
||||
Qformer = BertLMHeadModel(config=encoder_config)
|
||||
query_tokens = nn.Parameter(
|
||||
torch.zeros(1, num_query_token, encoder_config.hidden_size)
|
||||
)
|
||||
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
|
||||
return Qformer, query_tokens
|
||||
|
||||
@classmethod
|
||||
def init_vision_encoder(
|
||||
cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
|
||||
):
|
||||
assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
|
||||
visual_encoder = create_eva_vit_g(
|
||||
img_size, drop_path_rate, use_grad_checkpoint, precision
|
||||
)
|
||||
|
||||
ln_vision = LayerNorm(visual_encoder.num_features)
|
||||
return visual_encoder, ln_vision
|
||||
|
||||
def load_from_pretrained(self, url_or_filename):
|
||||
if is_url(url_or_filename):
|
||||
cached_file = download_cached_file(
|
||||
url_or_filename, check_hash=False, progress=True
|
||||
)
|
||||
checkpoint = torch.load(cached_file, map_location="cpu")
|
||||
elif os.path.isfile(url_or_filename):
|
||||
checkpoint = torch.load(url_or_filename, map_location="cpu")
|
||||
else:
|
||||
raise RuntimeError("checkpoint url or path is invalid")
|
||||
|
||||
state_dict = checkpoint["model"]
|
||||
|
||||
msg = self.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# logging.info("Missing keys {}".format(msg.missing_keys))
|
||||
logging.info("load checkpoint from %s" % url_or_filename)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
def compute_sim_matrix(model, data_loader, **kwargs):
|
||||
k_test = kwargs.pop("k_test")
|
||||
|
||||
metric_logger = MetricLogger(delimiter=" ")
|
||||
header = "Evaluation:"
|
||||
|
||||
logging.info("Computing features for evaluation...")
|
||||
start_time = time.time()
|
||||
|
||||
texts = data_loader.dataset.text
|
||||
num_text = len(texts)
|
||||
text_bs = 256
|
||||
text_ids = []
|
||||
text_embeds = []
|
||||
text_atts = []
|
||||
for i in range(0, num_text, text_bs):
|
||||
text = texts[i : min(num_text, i + text_bs)]
|
||||
text_input = model.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=35,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
text_feat = model.forward_text(text_input)
|
||||
text_embed = F.normalize(model.text_proj(text_feat))
|
||||
text_embeds.append(text_embed)
|
||||
text_ids.append(text_input.input_ids)
|
||||
text_atts.append(text_input.attention_mask)
|
||||
|
||||
text_embeds = torch.cat(text_embeds, dim=0)
|
||||
text_ids = torch.cat(text_ids, dim=0)
|
||||
text_atts = torch.cat(text_atts, dim=0)
|
||||
|
||||
vit_feats = []
|
||||
image_embeds = []
|
||||
for samples in data_loader:
|
||||
image = samples["image"]
|
||||
|
||||
image = image.to(model.device)
|
||||
image_feat, vit_feat = model.forward_image(image)
|
||||
image_embed = model.vision_proj(image_feat)
|
||||
image_embed = F.normalize(image_embed, dim=-1)
|
||||
|
||||
vit_feats.append(vit_feat.cpu())
|
||||
image_embeds.append(image_embed)
|
||||
|
||||
vit_feats = torch.cat(vit_feats, dim=0)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
|
||||
sims_matrix = []
|
||||
for image_embed in image_embeds:
|
||||
sim_q2t = image_embed @ text_embeds.t()
|
||||
sim_i2t, _ = sim_q2t.max(0)
|
||||
sims_matrix.append(sim_i2t)
|
||||
sims_matrix = torch.stack(sims_matrix, dim=0)
|
||||
|
||||
score_matrix_i2t = torch.full(
|
||||
(len(data_loader.dataset.image), len(texts)), -100.0
|
||||
).to(model.device)
|
||||
|
||||
num_tasks = dist_utils.get_world_size()
|
||||
rank = dist_utils.get_rank()
|
||||
step = sims_matrix.size(0) // num_tasks + 1
|
||||
start = rank * step
|
||||
end = min(sims_matrix.size(0), start + step)
|
||||
|
||||
for i, sims in enumerate(
|
||||
metric_logger.log_every(sims_matrix[start:end], 50, header)
|
||||
):
|
||||
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
|
||||
image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
|
||||
score = model.compute_itm(
|
||||
image_inputs=image_inputs,
|
||||
text_ids=text_ids[topk_idx],
|
||||
text_atts=text_atts[topk_idx],
|
||||
).float()
|
||||
score_matrix_i2t[start + i, topk_idx] = score + topk_sim
|
||||
|
||||
sims_matrix = sims_matrix.t()
|
||||
score_matrix_t2i = torch.full(
|
||||
(len(texts), len(data_loader.dataset.image)), -100.0
|
||||
).to(model.device)
|
||||
|
||||
step = sims_matrix.size(0) // num_tasks + 1
|
||||
start = rank * step
|
||||
end = min(sims_matrix.size(0), start + step)
|
||||
|
||||
for i, sims in enumerate(
|
||||
metric_logger.log_every(sims_matrix[start:end], 50, header)
|
||||
):
|
||||
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
|
||||
image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
|
||||
score = model.compute_itm(
|
||||
image_inputs=image_inputs,
|
||||
text_ids=text_ids[start + i].repeat(k_test, 1),
|
||||
text_atts=text_atts[start + i].repeat(k_test, 1),
|
||||
).float()
|
||||
score_matrix_t2i[start + i, topk_idx] = score + topk_sim
|
||||
|
||||
if dist_utils.is_dist_avail_and_initialized():
|
||||
dist.barrier()
|
||||
torch.distributed.all_reduce(
|
||||
score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
logging.info("Evaluation time {}".format(total_time_str))
|
||||
|
||||
return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
|
110
minigpt4/models/blip2_outputs.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.modeling_outputs import (
|
||||
ModelOutput,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlipSimilarity(ModelOutput):
|
||||
sim_i2t: torch.FloatTensor = None
|
||||
sim_t2i: torch.FloatTensor = None
|
||||
|
||||
sim_i2t_m: Optional[torch.FloatTensor] = None
|
||||
sim_t2i_m: Optional[torch.FloatTensor] = None
|
||||
|
||||
sim_i2t_targets: Optional[torch.FloatTensor] = None
|
||||
sim_t2i_targets: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlipIntermediateOutput(ModelOutput):
|
||||
"""
|
||||
Data class for intermediate outputs of BLIP models.
|
||||
|
||||
image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
|
||||
text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
|
||||
|
||||
image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
|
||||
text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
|
||||
|
||||
encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
|
||||
encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
|
||||
|
||||
decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
|
||||
decoder_labels (torch.LongTensor): labels for the captioning loss.
|
||||
|
||||
itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
|
||||
itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
|
||||
|
||||
"""
|
||||
|
||||
# uni-modal features
|
||||
image_embeds: torch.FloatTensor = None
|
||||
text_embeds: Optional[torch.FloatTensor] = None
|
||||
|
||||
image_embeds_m: Optional[torch.FloatTensor] = None
|
||||
text_embeds_m: Optional[torch.FloatTensor] = None
|
||||
|
||||
# intermediate outputs of multimodal encoder
|
||||
encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
|
||||
encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
|
||||
|
||||
itm_logits: Optional[torch.FloatTensor] = None
|
||||
itm_labels: Optional[torch.LongTensor] = None
|
||||
|
||||
# intermediate outputs of multimodal decoder
|
||||
decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
|
||||
decoder_labels: Optional[torch.LongTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlipOutput(ModelOutput):
|
||||
# some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
|
||||
sims: Optional[BlipSimilarity] = None
|
||||
|
||||
intermediate_output: BlipIntermediateOutput = None
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
|
||||
loss_itc: Optional[torch.FloatTensor] = None
|
||||
|
||||
loss_itm: Optional[torch.FloatTensor] = None
|
||||
|
||||
loss_lm: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlipOutputFeatures(ModelOutput):
|
||||
"""
|
||||
Data class of features from BlipFeatureExtractor.
|
||||
|
||||
Args:
|
||||
image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
|
||||
image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
|
||||
text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
|
||||
text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
|
||||
|
||||
The first embedding or feature is for the [CLS] token.
|
||||
|
||||
Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
|
||||
"""
|
||||
|
||||
image_embeds: Optional[torch.FloatTensor] = None
|
||||
image_embeds_proj: Optional[torch.FloatTensor] = None
|
||||
|
||||
text_embeds: Optional[torch.FloatTensor] = None
|
||||
text_embeds_proj: Optional[torch.FloatTensor] = None
|
||||
|
||||
multimodal_embeds: Optional[torch.FloatTensor] = None
|
442
minigpt4/models/eva_vit.py
Normal file
@ -0,0 +1,442 @@
|
||||
# Based on EVA, BEIT, timm and DeiT code bases
|
||||
# https://github.com/baaivision/EVA
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/facebookresearch/deit/
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||
from timm.models.registry import register_model
|
||||
|
||||
from minigpt4.common.dist_utils import download_cached_file
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'p={}'.format(self.drop_prob)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
# commit this for the orignal BERT implement
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
||||
proj_drop=0., window_size=None, attn_head_dim=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = \
|
||||
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
||||
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
relative_position_bias = \
|
||||
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
||||
window_size=None, attn_head_dim=None):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
if init_values is not None and init_values > 0:
|
||||
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
||||
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = \
|
||||
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = \
|
||||
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
||||
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
|
||||
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
||||
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
|
||||
super().__init__()
|
||||
self.image_size = img_size
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
if use_abs_pos_emb:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||
else:
|
||||
self.pos_embed = None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
||||
for i in range(depth)])
|
||||
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
||||
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
||||
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
# trunc_normal_(self.mask_token, std=.02)
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# trunc_normal_(self.head.weight, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
self.fix_init_weight()
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# self.head.weight.data.mul_(init_scale)
|
||||
# self.head.bias.data.mul_(init_scale)
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
||||
else:
|
||||
x = blk(x, rel_pos_bias)
|
||||
return x
|
||||
# x = self.norm(x)
|
||||
|
||||
# if self.fc_norm is not None:
|
||||
# t = x[:, 1:, :]
|
||||
# return self.fc_norm(t.mean(1))
|
||||
# else:
|
||||
# return x[:, 0]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
# x = self.head(x)
|
||||
return x
|
||||
|
||||
def get_intermediate_layers(self, x):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
features = []
|
||||
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
||||
for blk in self.blocks:
|
||||
x = blk(x, rel_pos_bias)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def interpolate_pos_embed(model, checkpoint_model):
|
||||
if 'pos_embed' in checkpoint_model:
|
||||
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches ** 0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
||||
pos_tokens = torch.nn.functional.interpolate(
|
||||
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
||||
checkpoint_model['pos_embed'] = new_pos_embed
|
||||
|
||||
|
||||
def convert_weights_to_fp16(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
l.weight.data = l.weight.data.half()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.half()
|
||||
|
||||
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
||||
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
# tensor = getattr(l, attr)
|
||||
# if tensor is not None:
|
||||
# tensor.data = tensor.data.half()
|
||||
|
||||
model.apply(_convert_weights_to_fp16)
|
||||
|
||||
|
||||
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
|
||||
model = VisionTransformer(
|
||||
img_size=img_size,
|
||||
patch_size=14,
|
||||
use_mean_pooling=False,
|
||||
embed_dim=1408,
|
||||
depth=39,
|
||||
num_heads=1408//88,
|
||||
mlp_ratio=4.3637,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
||||
cached_file = download_cached_file(
|
||||
url, check_hash=False, progress=True
|
||||
)
|
||||
state_dict = torch.load(cached_file, map_location="cpu")
|
||||
interpolate_pos_embed(model,state_dict)
|
||||
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
||||
# print(incompatible_keys)
|
||||
|
||||
if precision == "fp16":
|
||||
# model.to("cuda")
|
||||
convert_weights_to_fp16(model)
|
||||
return model
|
242
minigpt4/models/mini_gpt4.py
Normal file
@ -0,0 +1,242 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast as autocast
|
||||
import torch.nn as nn
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
from minigpt4.models.blip2 import Blip2Base, disabled_train
|
||||
from minigpt4.models.modeling_llama import LlamaForCausalLM
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
|
||||
@registry.register_model("mini_gpt4")
|
||||
class MiniGPT4(Blip2Base):
|
||||
"""
|
||||
BLIP2 GPT-LLAMA model.
|
||||
"""
|
||||
|
||||
PRETRAINED_MODEL_CONFIG_DICT = {
|
||||
"pretrain_vicuna": "configs/models/minigpt4.yaml",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vit_model="eva_clip_g",
|
||||
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
|
||||
img_size=224,
|
||||
drop_path_rate=0,
|
||||
use_grad_checkpoint=False,
|
||||
vit_precision="fp16",
|
||||
freeze_vit=True,
|
||||
freeze_qformer=True,
|
||||
num_query_token=32,
|
||||
llama_model="",
|
||||
prompt_path="",
|
||||
prompt_template="",
|
||||
max_txt_len=32,
|
||||
end_sym='\n',
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = self.init_tokenizer()
|
||||
|
||||
print('Loading VIT')
|
||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
||||
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
||||
)
|
||||
if freeze_vit:
|
||||
for name, param in self.visual_encoder.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.visual_encoder = self.visual_encoder.eval()
|
||||
self.visual_encoder.train = disabled_train
|
||||
for name, param in self.ln_vision.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.ln_vision = self.ln_vision.eval()
|
||||
self.ln_vision.train = disabled_train
|
||||
logging.info("freeze vision encoder")
|
||||
print('Loading VIT Done')
|
||||
|
||||
print('Loading Q-Former')
|
||||
self.Qformer, self.query_tokens = self.init_Qformer(
|
||||
num_query_token, self.visual_encoder.num_features
|
||||
)
|
||||
self.Qformer.cls = None
|
||||
self.Qformer.bert.embeddings.word_embeddings = None
|
||||
self.Qformer.bert.embeddings.position_embeddings = None
|
||||
for layer in self.Qformer.bert.encoder.layer:
|
||||
layer.output = None
|
||||
layer.intermediate = None
|
||||
self.load_from_pretrained(url_or_filename=q_former_model)
|
||||
|
||||
if freeze_qformer:
|
||||
for name, param in self.Qformer.named_parameters():
|
||||
param.requires_grad = False
|
||||
self.Qformer = self.Qformer.eval()
|
||||
self.Qformer.train = disabled_train
|
||||
self.query_tokens.requires_grad = False
|
||||
logging.info("freeze Qformer")
|
||||
print('Loading Q-Former Done')
|
||||
|
||||
print('Loading LLAMA')
|
||||
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
|
||||
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
|
||||
|
||||
self.llama_model = LlamaForCausalLM.from_pretrained(
|
||||
llama_model, torch_dtype=torch.float16
|
||||
)
|
||||
for name, param in self.llama_model.named_parameters():
|
||||
param.requires_grad = False
|
||||
print('Loading LLAMA Done')
|
||||
|
||||
self.llama_proj = nn.Linear(
|
||||
self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
|
||||
)
|
||||
self.max_txt_len = max_txt_len
|
||||
self.end_sym = end_sym
|
||||
|
||||
if prompt_path:
|
||||
with open(prompt_path, 'r') as f:
|
||||
raw_prompts = f.read().splitlines()
|
||||
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
|
||||
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
|
||||
print('Load {} training prompts'.format(len(self.prompt_list)))
|
||||
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
|
||||
else:
|
||||
self.prompt_list = []
|
||||
|
||||
def encode_img(self, image):
|
||||
with self.maybe_autocast():
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image))
|
||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
|
||||
image.device
|
||||
)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_output = self.Qformer.bert(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
inputs_llama = self.llama_proj(query_output.last_hidden_state)
|
||||
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
||||
return inputs_llama, atts_llama
|
||||
|
||||
def prompt_wrap(self, img_embeds, atts_img, prompt):
|
||||
if prompt:
|
||||
batch_size = img_embeds.shape[0]
|
||||
p_before, p_after = prompt.split('<ImageHere>')
|
||||
p_before_tokens = self.llama_tokenizer(
|
||||
p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||
p_after_tokens = self.llama_tokenizer(
|
||||
p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
|
||||
p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
|
||||
wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
|
||||
wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
|
||||
return wrapped_img_embeds, wrapped_atts_img
|
||||
else:
|
||||
return img_embeds, atts_img
|
||||
|
||||
def forward(self, samples):
|
||||
image = samples["image"]
|
||||
img_embeds, atts_img = self.encode_img(image)
|
||||
if hasattr(samples, 'question_split'): # VQA dataset
|
||||
print('VQA Batch')
|
||||
vqa_prompt = '###Human: <Img><ImageHere></Img> '
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
|
||||
elif self.prompt_list:
|
||||
prompt = random.choice(self.prompt_list)
|
||||
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
|
||||
|
||||
self.llama_tokenizer.padding_side = "right"
|
||||
|
||||
text = [t + self.end_sym for t in samples["text_input"]]
|
||||
|
||||
to_regress_tokens = self.llama_tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
truncation=True,
|
||||
max_length=self.max_txt_len,
|
||||
add_special_tokens=False
|
||||
).to(image.device)
|
||||
|
||||
targets = to_regress_tokens.input_ids.masked_fill(
|
||||
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
|
||||
)
|
||||
|
||||
empty_targets = (
|
||||
torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
|
||||
dtype=torch.long).to(image.device).fill_(-100) # plus one for bos
|
||||
)
|
||||
targets = torch.cat([empty_targets, targets], dim=1)
|
||||
|
||||
batch_size = img_embeds.shape[0]
|
||||
bos = torch.ones([batch_size, 1],
|
||||
dtype=to_regress_tokens.input_ids.dtype,
|
||||
device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
|
||||
bos_embeds = self.llama_model.model.embed_tokens(bos)
|
||||
atts_bos = atts_img[:, :1]
|
||||
|
||||
to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
|
||||
inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
|
||||
attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
|
||||
|
||||
with self.maybe_autocast():
|
||||
outputs = self.llama_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
labels=targets,
|
||||
)
|
||||
loss = outputs.loss
|
||||
|
||||
return {"loss": loss}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg):
|
||||
vit_model = cfg.get("vit_model", "eva_clip_g")
|
||||
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
|
||||
img_size = cfg.get("image_size")
|
||||
num_query_token = cfg.get("num_query_token")
|
||||
llama_model = cfg.get("llama_model")
|
||||
|
||||
drop_path_rate = cfg.get("drop_path_rate", 0)
|
||||
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
|
||||
vit_precision = cfg.get("vit_precision", "fp16")
|
||||
freeze_vit = cfg.get("freeze_vit", True)
|
||||
freeze_qformer = cfg.get("freeze_qformer", True)
|
||||
|
||||
prompt_path = cfg.get("prompt_path", "")
|
||||
prompt_template = cfg.get("prompt_template", "")
|
||||
max_txt_len = cfg.get("max_txt_len", 32)
|
||||
end_sym = cfg.get("end_sym", '\n')
|
||||
|
||||
model = cls(
|
||||
vit_model=vit_model,
|
||||
q_former_model=q_former_model,
|
||||
img_size=img_size,
|
||||
drop_path_rate=drop_path_rate,
|
||||
use_grad_checkpoint=use_grad_checkpoint,
|
||||
vit_precision=vit_precision,
|
||||
freeze_vit=freeze_vit,
|
||||
freeze_qformer=freeze_qformer,
|
||||
num_query_token=num_query_token,
|
||||
llama_model=llama_model,
|
||||
prompt_path=prompt_path,
|
||||
prompt_template=prompt_template,
|
||||
max_txt_len=max_txt_len,
|
||||
end_sym=end_sym
|
||||
)
|
||||
|
||||
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
|
||||
if ckpt_path:
|
||||
print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
msg = model.load_state_dict(ckpt['model'], strict=False)
|
||||
|
||||
return model
|
755
minigpt4/models/modeling_llama.py
Normal file
@ -0,0 +1,755 @@
|
||||
# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
|
||||
""" PyTorch LLaMA model."""
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class LlamaRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
||||
return (
|
||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.act_fn = ACT2FN[hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = LlamaAttention(config=config)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
LLAMA_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`LlamaConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
LLAMA_START_DOCSTRING,
|
||||
)
|
||||
class LlamaPreTrainedModel(PreTrainedModel):
|
||||
config_class = LlamaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlamaDecoderLayer"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, LlamaModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
LLAMA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
LLAMA_START_DOCSTRING,
|
||||
)
|
||||
class LlamaModel(LlamaPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: LlamaConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
query_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
if query_embeds is not None:
|
||||
inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = LlamaModel(config)
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
query_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
query_embeds=query_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||
):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
query_embeds = None
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"query_embeds": query_embeds,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||
return reordered_past
|
||||
|
33
minigpt4/processors/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
from minigpt4.processors.base_processor import BaseProcessor
|
||||
from minigpt4.processors.blip_processors import (
|
||||
Blip2ImageTrainProcessor,
|
||||
Blip2ImageEvalProcessor,
|
||||
BlipCaptionProcessor,
|
||||
)
|
||||
|
||||
from minigpt4.common.registry import registry
|
||||
|
||||
__all__ = [
|
||||
"BaseProcessor",
|
||||
"Blip2ImageTrainProcessor",
|
||||
"Blip2ImageEvalProcessor",
|
||||
"BlipCaptionProcessor",
|
||||
]
|
||||
|
||||
|
||||
def load_processor(name, cfg=None):
|
||||
"""
|
||||
Example
|
||||
|
||||
>>> processor = load_processor("alpro_video_train", cfg=None)
|
||||
"""
|
||||
processor = registry.get_processor_class(name).from_config(cfg)
|
||||
|
||||
return processor
|
26
minigpt4/processors/base_processor.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class BaseProcessor:
|
||||
def __init__(self):
|
||||
self.transform = lambda x: x
|
||||
return
|
||||
|
||||
def __call__(self, item):
|
||||
return self.transform(item)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
return cls()
|
||||
|
||||
def build(self, **kwargs):
|
||||
cfg = OmegaConf.create(kwargs)
|
||||
|
||||
return self.from_config(cfg)
|