41 MiniGPT4_Train.md (new file)
@@ -0,0 +1,41 @@
## Training of MiniGPT-4

The training of MiniGPT-4 contains two alignment stages.

**1. First pretraining stage**

In the first pretraining stage, the model is trained on image-text pairs from the Laion and CC datasets
to align the vision and language models. To download and prepare the datasets, please check
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
After the first stage, the visual features are mapped into a representation the language
model can understand.
To launch the first stage training, run the following command. In our experiments, we use 4 A100s;
replace `NUM_GPU` with the number of GPUs available to you.
You can change the save path in the config file
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml).
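A minimal sketch of the save-path setting (the key name is assumed from the repo's config layout, so verify it against the actual file):

```yaml
run:
  # directory where checkpoints and logs are written (assumed key name)
  output_dir: "output/minigpt4_stage1_pretrain"
```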

```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
```

A MiniGPT-4 checkpoint with only stage-one training can be downloaded
[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
Compared to the model after stage two, this checkpoint frequently generates incomplete and repeated sentences.

**2. Second finetuning stage**

In the second stage, we use a small, high-quality image-text pair dataset created by ourselves
and convert it to a conversation format to further align MiniGPT-4.
To download and prepare our second stage dataset, please check our
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
To launch the second stage alignment,
first specify the path to the checkpoint file trained in stage 1 in
[train_configs/minigpt4_stage2_finetune.yaml](train_configs/minigpt4_stage2_finetune.yaml).
You can also specify the output path there.
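A minimal sketch of the two fields to set, assuming the stage-2 config follows the same layout as the stage-1 config (the checkpoint file name is illustrative):

```yaml
model:
  # path to the checkpoint produced by stage 1 (file name illustrative)
  ckpt: "/path/to/stage1/checkpoint.pth"
run:
  output_dir: "output/minigpt4_stage2_finetune"
```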
Then, run the following command. In our experiments, we use 1 A100.

```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
```

After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and in a user-friendly way.
BIN MiniGPT_4.pdf
BIN MiniGPTv2.pdf (new file)
@@ -1,35 +0,0 @@ (file deleted)
## How to Prepare Vicuna Weight

Vicuna is an open-source LLAMA-based LLM that performs close to ChatGPT.
We currently use the v0 version of Vicuna-13B.

To prepare Vicuna's weight, first download Vicuna's **delta** weight from [https://huggingface.co/lmsys/vicuna-13b-delta-v0](https://huggingface.co/lmsys/vicuna-13b-delta-v0).
If you have git-lfs installed (https://git-lfs.com), this can be done by

```
git lfs install
git clone https://huggingface.co/lmsys/vicuna-13b-delta-v0  # more powerful, needs at least 24G of GPU memory
# or
git clone https://huggingface.co/lmsys/vicuna-7b-delta-v0  # smaller, needs 12G of GPU memory
```

Note that this is not directly the working weight, but the difference between the working weight and the original weight of LLAMA-13B. (Due to LLAMA's rules, we cannot distribute the weight of LLAMA.)

Then, you need to obtain the original LLAMA-7B or LLAMA-13B weights in the HuggingFace format,
either following the instruction provided by HuggingFace
[here](https://huggingface.co/docs/transformers/main/model_doc/llama) or from the Internet.

When these two weights are ready, we can use tools from Vicuna's team to create the real working weight.
First, install their library that is compatible with v0 Vicuna:

```
pip install git+https://github.com/lm-sys/FastChat.git@v0.1.10
```

Then, run the following command to create the final working weight:

```
python -m fastchat.model.apply_delta --base /path/to/llama-13bOR7b-hf/ --target /path/to/save/working/vicuna/weight/ --delta /path/to/vicuna-13bOR7b-delta-v0/
```
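For context, the merged directory passed to `--target` is what the model configs later reference through their `llama_model` field (field name taken from this repo's model configs; the path value is illustrative):

```yaml
# minigpt4/configs/models/minigpt4_vicuna0.yaml (excerpt)
llama_model: "/path/to/save/working/vicuna/weight/"
```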

Now you are good to go!
153 README.md
@@ -1,24 +1,48 @@
-# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
+# MiniGPT-V

-[Deyao Zhu](https://tsutikgiau.github.io/)* , [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
-
-**King Abdullah University of Science and Technology**
+<font size='5'>**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning**</font>
+
+Jun Chen, Deyao Zhu, Xiaoqian Shen, Xiang Li, Zechun Liu, Pengchuan Zhang, Raghuraman Krishnamoorthi, Vikas Chandra, Yunyang Xiong☨, Mohamed Elhoseiny☨
+
+☨equal last author
+
+<a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://github.com/Vision-CAIR/MiniGPT-4/blob/main/MiniGPTv2.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a> <a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Gradio-Demo-blue'></a> [](https://www.youtube.com/watch?v=atFCwV2hSY4)
+
+<font size='5'>**MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models**</font>
+
+Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
+
+*equal contribution
+
 <a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
+
+*King Abdullah University of Science and Technology*

 ## 💡 Get help - [Q&A](https://github.com/Vision-CAIR/MiniGPT-4/discussions/categories/q-a) or [Discord 💬](https://discord.gg/5WdJkjbAeE)


 ## News
-We now provide a llama 2 version of MiniGPT-4
+[Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2
+
+[Aug.28 2023] We now provide a llama 2 version of MiniGPT-4

 ## Online Demo

+Click the image to chat with MiniGPT-v2 around your images
+[](https://minigpt-v2.github.io/)
+
 Click the image to chat with MiniGPT-4 around your images
 [](https://minigpt-4.github.io)


-## Examples
+## MiniGPT-v2 Examples
+
+
+
+## MiniGPT-4 Examples
   |   |
 :-------------------------:|:-------------------------:
  | 
@@ -28,17 +52,6 @@ More examples can be found in the [project page](https://minigpt-4.github.io).

-## Introduction
-- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer.
-- We train MiniGPT-4 with two stages. The first traditional pretraining stage is trained using roughly 5 million aligned image-text pairs in 10 hours using 4 A100s. After the first stage, Vicuna is able to understand the image. But the generation ability of Vicuna is heavily impacted.
-- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset.
-- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100.
-- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4.
-
-

 ## Getting Started
 ### Installation

@@ -56,42 +69,62 @@ conda activate minigpt4

 **2. Prepare the pretrained LLM weights**

-Currently, we provide both Vicuna V0 and Llama 2 version of MiniGPT-4.
+**MiniGPT-v2** is based on Llama2 Chat 7B. For **MiniGPT-4**, we have both Vicuna V0 and Llama 2 versions.
 Download the corresponding LLM weights from the following Hugging Face repositories by cloning them with git-lfs.

-| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B |
+| Llama 2 Chat 7B | Vicuna V0 13B | Vicuna V0 7B |
 :------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
-[Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main)
+[Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main)

-Then, set the path to the vicuna weight in the model config file
-[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18
-and/or the path to the llama2 weight in the model config file
+Then, set the variable *llama_model* in the model config file to the LLM weight path.
+
+* For MiniGPT-v2, set the LLM path
+[here](minigpt4/configs/models/minigpt_v2.yaml#L14) at Line 14.
+
+* For MiniGPT-4 (Llama2), set the LLM path
 [here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
+
+* For MiniGPT-4 (Vicuna), set the LLM path
+[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18.
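For illustration, the line being edited looks like this in each config (the path value is a placeholder for wherever you stored the weights):

```yaml
# e.g. minigpt4/configs/models/minigpt_v2.yaml
llama_model: "/path/to/Llama-2-7b-chat-hf"
```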

-**3. Prepare the pretrained MiniGPT-4 checkpoint**
+**3. Prepare the pretrained model checkpoints**

-Download the pretrained checkpoints according to the Vicuna model you prepare.
-
-| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B |
-:------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
-[Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing)
-
-Then, set the path to the pretrained checkpoint in the evaluation config file
+Download the pretrained model checkpoints
+
+| MiniGPT-v2 (LLaMA-2 Chat 7B) |
+|------------------------------|
+| [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
+
+For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file
+in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L8) at Line 8.
+
+| MiniGPT-4 (Vicuna 13B) | MiniGPT-4 (Vicuna 7B) | MiniGPT-4 (LLaMA-2 Chat 7B) |
+|----------------------------|---------------------------|---------------------------------|
+| [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) |
+
+For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file
 in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L8) at Line 8 for the Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L8) for the Llama 2 version.
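For illustration, the field being set is the `ckpt` entry near the top of each eval config (the checkpoint file name is illustrative):

```yaml
# e.g. eval_configs/minigptv2_eval.yaml
ckpt: "/path/to/minigptv2_checkpoint.pth"
```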


 ### Launching Demo Locally

-Try out our demo [demo.py](demo.py) for the vicuna version on your local machine by running
+For MiniGPT-v2, run
+```
+python demo_v2.py --cfg-path eval_configs/minigptv2_eval.yaml --gpu-id 0
+```
+
+For MiniGPT-4 (Vicuna version), run

 ```
 python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
 ```

-or for Llama 2 version by
+For MiniGPT-4 (Llama2 version), run

 ```
 python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
 ```

@@ -101,52 +134,17 @@ python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
 To save GPU memory, LLMs load in 8-bit by default, with a beam search width of 1.
 This configuration requires about 23G of GPU memory for the 13B LLM and 11.5G for the 7B LLM.
 For more powerful GPUs, you can run the model
-in 16 bit by setting `low_resource` to `False` in the relevant config file
-(line 6 of either [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6) if using Vicuna or [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6) if using Llama 2) and use a larger beam search width.
-
-Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
+in 16 bit by setting `low_resource` to `False` in the relevant config file:
+
+* MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#6)
+* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6)
+* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6)
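The toggle itself is a single line near the top of each eval config; a sketch of the 16-bit setting:

```yaml
low_resource: False  # False loads the LLM in 16-bit; the default True uses 8-bit to save memory
```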
+
+Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)


 ### Training
-The training of MiniGPT-4 contains two alignment stages.
-
-**1. First pretraining stage**
-[... the rest of the old inline two-stage training instructions, moved verbatim to MiniGPT4_Train.md at the top of this diff ...]
+For training details of MiniGPT-4, check [here](MiniGPT4_Train.md).

@@ -156,10 +154,19 @@ After the second stage alignment, MiniGPT-4 is able to talk about the image cohe
 + [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check this great open-source work if you don't know it before!
 + [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
 + [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
++ [LLaMA](https://github.com/facebookresearch/llama) The strong open-sourced LLaMA 2 language model.


-If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX:
+If you're using MiniGPT-4/MiniGPT-v2 in your research or applications, please cite using this BibTeX:
 ```bibtex
+@article{Chen2023minigpt,
+      title={MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning},
+      author={Chen, Jun and Zhu, Deyao and Shen, Xiaoqian and Li, Xiang and Liu, Zechun and Zhang, Pengchuan and Krishnamoorthi, Raghuraman and Chandra, Vikas and Xiong, Yunyang and Elhoseiny, Mohamed},
+      journal={github},
+      year={2023}
+}
+
 @article{zhu2023minigpt,
   title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
   author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},
16 demo.py
@@ -7,10 +7,12 @@ import torch
 import torch.backends.cudnn as cudnn
 import gradio as gr

+from transformers import StoppingCriteriaList
+
 from minigpt4.common.config import Config
 from minigpt4.common.dist_utils import get_rank
 from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
+from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub

 # imports modules for registration
 from minigpt4.datasets.builders import *
@@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type]

 vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
 vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+
+stop_words_ids = [[835], [2277, 29937]]
+stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
+stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
+chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
 print('Initialization Finished')

@@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state):
     chat_state = CONV_VISION.copy()
     img_list = []
     llm_message = chat.upload_img(gr_img, chat_state, img_list)
+    chat.encode_img(img_list)
     return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list

@@ -124,7 +132,7 @@ with gr.Blocks() as demo:
     gr.Markdown(article)

     with gr.Row():
-        with gr.Column(scale=0.5):
+        with gr.Column(scale=1):
             image = gr.Image(type="pil")
             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
             clear = gr.Button("Restart")
@@ -147,7 +155,7 @@ with gr.Blocks() as demo:
                 label="Temperature",
             )

-        with gr.Column():
+        with gr.Column(scale=2):
             chat_state = gr.State()
             img_list = gr.State()
             chatbot = gr.Chatbot(label='MiniGPT-4')
662 demo_v2.py (new file)
@@ -0,0 +1,662 @@
import argparse
import os
import random
from collections import defaultdict

import cv2
import re

import numpy as np
from PIL import Image
import torch
import html
import gradio as gr

import torchvision.transforms as T
import torch.backends.cudnn as cudnn

from minigpt4.common.config import Config

from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
                        help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

cudnn.benchmark = False
cudnn.deterministic = True

print('Initializing Chat')
args = parse_args()
cfg = Config(args)

device = 'cuda:{}'.format(args.gpu_id)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100
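# boxes in model output are expressed on a 100x100 grid; mask2bbox and
# visualize_all_bbox_together below scale coordinates by this value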

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

model = model.eval()

CONV_VISION = Conversation(
    system="",
    roles=(r"<s>[INST] ", r" [/INST]"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)


def extract_substrings(string):
    # truncate at the last closing brace in case the final bracket is unfinished
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    pattern = r'<p>(.*?)\}(?!<)'
    matches = re.findall(pattern, string)
    substrings = [match for match in matches]

    return substrings


def is_overlapping(rect1, rect2):
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def computeIoU(bbox1, bbox2):
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2
    intersection_x1 = max(x1, x3)
    intersection_y1 = max(y1, y3)
    intersection_x2 = min(x2, x4)
    intersection_y2 = min(y2, y4)
    intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area
    iou = intersection_area / union_area
    return iou


def save_tmp_img(visual_img):
    file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
    file_path = "/tmp/" + file_name
    visual_img.save(file_path)
    return file_path


def mask2bbox(mask):
    if mask is None:
        return ''
    mask = mask.resize([100, 100], resample=Image.NEAREST)
    mask = np.array(mask)[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)

    if rows.sum():
        # Get the top, bottom, left, and right boundaries
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
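        # format the box as the model's textual tag {<x1><y1><x2><y2>} on the 100x100 grid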
        bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
    else:
        bbox = ''

    return bbox


def escape_markdown(text):
    # List of Markdown special characters that need to be escaped
    md_chars = ['<', '>']

    # Escape each special character
    for char in md_chars:
        text = text.replace(char, '\\' + char)

    return text


def reverse_escape(text):
    md_chars = ['\\<', '\\>']

    for char in md_chars:
        text = text.replace(char, char[1:])

    return text


colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (210, 210, 0),
    (255, 0, 255),
    (0, 255, 255),
    (114, 128, 250),
    (0, 165, 255),
    (0, 128, 0),
    (144, 238, 144),
    (238, 238, 175),
    (255, 191, 0),
    (0, 128, 0),
    (226, 43, 138),
    (255, 0, 255),
    (0, 215, 255),
]
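# palette for boxes and labels; color_map exposes each color as a hex string keyed by its index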
color_map = {
    f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
    color_id, color in enumerate(colors)
}

used_colors = colors


def visualize_all_bbox_together(image, generation):
    if image is None:
        return None, ''

    generation = html.unescape(generation)
    print('gen begin', generation)

    image_width, image_height = image.size
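    # rescale to a fixed display width of 500 px; grid coordinates are mapped to this size below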
    image = image.resize([500, int(500 / image_width * image_height)])
    image_width, image_height = image.size

    string_list = extract_substrings(generation)
    if string_list:  # it is grounding or detection
        mode = 'all'
        entities = defaultdict(list)
        i = 0
        j = 0
        for string in string_list:
            try:
                obj, string = string.split('</p>')
            except ValueError:
                print('wrong string: ', string)
                continue
            bbox_list = string.split('<delim>')
            flag = False
            for bbox_string in bbox_list:
                integers = re.findall(r'-?\d+', bbox_string)
                if len(integers) == 4:
                    x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
                    left = x0 / bounding_box_size * image_width
                    bottom = y0 / bounding_box_size * image_height
                    right = x1 / bounding_box_size * image_width
                    top = y1 / bounding_box_size * image_height

                    entities[obj].append([left, bottom, right, top])

                    j += 1
                    flag = True
            if flag:
                i += 1
    else:
        integers = re.findall(r'-?\d+', generation)

        if len(integers) == 4:  # it is refer
            mode = 'single'

            entities = list()
            x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
            left = x0 / bounding_box_size * image_width
            bottom = y0 / bounding_box_size * image_height
            right = x1 / bounding_box_size * image_width
            top = y1 / bounding_box_size * image_height
            entities.append([left, bottom, right, top])
        else:
            # don't detect any valid bbox to visualize
            return None, ''

    if len(entities) == 0:
        return None, ''

    if isinstance(image, Image.Image):
        image_h = image.height
        image_w = image.width
        image = np.array(image)

    elif isinstance(image, str):
        if os.path.exists(image):
            pil_img = Image.open(image).convert("RGB")
            image = np.array(pil_img)[:, :, [2, 1, 0]]
            image_h = pil_img.height
            image_w = pil_img.width
        else:
            raise ValueError(f"invalid image path, {image}")
    elif isinstance(image, torch.Tensor):

        image_tensor = image.cpu()
        reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
        reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
        image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
        pil_img = T.ToPILImage()(image_tensor)
        image_h = pil_img.height
        image_w = pil_img.width
        image = np.array(pil_img)[:, :, [2, 1, 0]]
    else:
        raise ValueError(f"invalid image format, {type(image)} for {image}")

    indices = list(range(len(entities)))

    new_image = image.copy()

    previous_bboxes = []
    # size of text
    text_size = 0.5
    # thickness of text
    text_line = 1  # int(max(1 * min(image_h, image_w) / 512, 1))
    box_line = 2
    (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
    base_height = int(text_height * 0.675)
    text_offset_original = text_height - base_height
    text_spaces = 2

    # num_bboxes = sum(len(x[-1]) for x in entities)
    used_colors = colors  # random.sample(colors, k=num_bboxes)

    color_id = -1
    for entity_idx, entity_name in enumerate(entities):
        if mode == 'single' or mode == 'identify':
            bboxes = entity_name
            bboxes = [bboxes]
        else:
            bboxes = entities[entity_name]
        color_id += 1
        for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
            skip_flag = False
            orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)

            color = used_colors[entity_idx % len(used_colors)]  # tuple(np.random.randint(0, 255, size=3).tolist())
            new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)

            if mode == 'all':
                l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1

                x1 = orig_x1 - l_o
                y1 = orig_y1 - l_o

                if y1 < text_height + text_offset_original + 2 * text_spaces:
                    y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
                    x1 = orig_x1 + r_o

                # add text background
                (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
                                                               text_line)
                text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
                        text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1

                for prev_bbox in previous_bboxes:
                    if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
                            prev_bbox['phrase'] == entity_name:
                        skip_flag = True
                        break
                    while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
                        text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
                        text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
                        y1 += (text_height + text_offset_original + 2 * text_spaces)

                        if text_bg_y2 >= image_h:
                            text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
                            text_bg_y2 = image_h
                            y1 = image_h
                            break
                if not skip_flag:
                    alpha = 0.5
                    for i in range(text_bg_y1, text_bg_y2):
                        for j in range(text_bg_x1, text_bg_x2):
                            if i < image_h and j < image_w:
                                if j < text_bg_x1 + 1.35 * c_width:
                                    # original color
                                    bg_color = color
                                else:
                                    # white
                                    bg_color = [255, 255, 255]
                                new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
                                    np.uint8)

                    cv2.putText(
                        new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
                        cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
                    )

                    previous_bboxes.append(
                        {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})

    if mode == 'all':
        def color_iterator(colors):
            while True:
                for color in colors:
                    yield color

        color_gen = color_iterator(colors)

        # Add colors to phrases and remove <p></p>
        def colored_phrases(match):
            phrase = match.group(1)
            color = next(color_gen)
            return f'<span style="color:rgb{color}">{phrase}</span>'

        print('gen before', generation)
        generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
        print('gen after', generation)
        generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
    else:
        generation_colored = ''

    pil_image = Image.fromarray(new_image)
    return pil_image, generation_colored


def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
                                                                    interactive=True), chat_state, img_list


def image_upload_trigger(upload_flag, replace_flag, img_list):
    # set the upload flag to true when receiving a new image.
    # if there is an old image (and an old conversation), set the replace flag to true to reset the conv later.
    print('flag', upload_flag, replace_flag)
    print("SET UPLOAD FLAG!")
    upload_flag = 1
    if img_list:
        print("SET REPLACE FLAG!")
        replace_flag = 1
    print('flag', upload_flag, replace_flag)
    return upload_flag, replace_flag


def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    # set the upload flag to true when receiving a new image.
    # if there is an old image (and an old conversation), set the replace flag to true to reset the conv later.
    print('flag', upload_flag, replace_flag)
    print("SET UPLOAD FLAG!")
    upload_flag = 1
    if img_list or replace_flag == 1:
        print("SET REPLACE FLAG!")
        replace_flag = 1

    print('flag', upload_flag, replace_flag)
    return upload_flag, replace_flag


def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']
    else:
        mask = None

    if '[identify]' in user_message:
        # check if the user provided a bbox in the text input
        integers = re.findall(r'-?\d+', user_message)
        if len(integers) != 4:  # no bbox in text
            bbox = mask2bbox(mask)
            user_message = user_message + bbox

    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state

    if chat_state is None:
        chat_state = CONV_VISION.copy()

    print('upload flag: {}'.format(upload_flag))
    if upload_flag:
        if replace_flag:
            print('RESET!!!!!!!')
            chat_state = CONV_VISION.copy()  # new image, reset everything
            replace_flag = 0
            chatbot = []
        print('UPLOAD IMAGE!!')
        img_list = []
        llm_message = chat.upload_img(gr_img, chat_state, img_list)
        upload_flag = 0

    chat.ask(user_message, chat_state)

    chatbot = chatbot + [[user_message, None]]

    if '[identify]' in user_message:
        visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
        if visual_img is not None:
            print('Visualizing the input')
            file_path = save_tmp_img(visual_img)
            chatbot = chatbot + [[(file_path,), None]]

    return '', chatbot, chat_state, img_list, upload_flag, replace_flag


def gradio_answer(chatbot, chat_state, img_list, temperature):
    llm_message = chat.answer(conv=chat_state,
                              img_list=img_list,
                              temperature=temperature,
                              max_new_tokens=500,
                              max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state


def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
    print('chat state', chat_state.get_prompt())
    if not isinstance(img_list[0], torch.Tensor):
        chat.encode_img(img_list)
    streamer = chat.stream_answer(conv=chat_state,
                                  img_list=img_list,
                                  temperature=temperature,
                                  max_new_tokens=500,
                                  max_length=2000)
    output = ''
    for new_output in streamer:
        escapped = escape_markdown(new_output)
        output += escapped
        chatbot[-1][1] = output
        yield chatbot, chat_state
    # print('message: ', chat_state.messages)
    chat_state.messages[-1][1] = '</s>'
    return chatbot, chat_state


def gradio_visualize(chatbot, gr_img):
    if isinstance(gr_img, dict):
        gr_img, mask = gr_img['image'], gr_img['mask']

    unescaped = reverse_escape(chatbot[-1][1])
    visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
    if visual_img is not None:
        print('Visualizing the output')
        if len(generation_color):
            chatbot[-1][1] = generation_color
        file_path = save_tmp_img(visual_img)
        chatbot = chatbot + [[None, (file_path,)]]

    return chatbot


def gradio_taskselect(idx):
    prompt_list = [
        '',
        '[grounding] describe this image in detail',
        '[refer] ',
        '[detection] ',
        '[identify] what is this ',
        '[vqa] '
    ]
    instruct_list = [
        '**Hint:** Type in whatever you want',
        '**Hint:** Send the command to generate a grounded image description',
        '**Hint:** Type in a phrase about an object in the image and send the command',
        '**Hint:** Type in a caption or phrase, and see object locations in the image',
        '**Hint:** Draw a bounding box on the uploaded image, then send the command. Click the "clear" button on the top right of the image before redrawing',
        '**Hint:** Send a question to get a short answer',
    ]
    return prompt_list[idx], instruct_list[idx]


chat = Chat(model, vis_processor, device=device)

title = """<h1 align="center">MiniGPT-v2 Demo</h1>"""
description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
# article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4/blob/main/MiniGPTv2.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p><a href='https://www.youtube.com/watch?v=atFCwV2hSY4'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p>"""
article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""

introduction = '''
For Abilities Involving Visual Grounding:
1. Grounding: CLICK **Send** to generate a grounded image description.
2. Refer: Input a referring object and CLICK **Send**.
3. Detection: Write a caption or phrase, and CLICK **Send**.
4. Identify: Draw a bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK the "clear" button before re-drawing next time.)
5. VQA: Input a visual question and CLICK **Send**.
6. No Tag: Input whatever you want and CLICK **Send** without any tagging

You can also simply chat in free form!
'''

text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
                        scale=8)
with gr.Blocks() as demo:
    gr.Markdown(title)
    # gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil", tool='sketch', brush_radius=20)

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            clear = gr.Button("Restart")

            gr.Markdown(introduction)

        with gr.Column():
            chat_state = gr.State(value=None)
            img_list = gr.State(value=[])
            chatbot = gr.Chatbot(label='MiniGPT-v2')

            dataset = gr.Dataset(
                components=[gr.Textbox(visible=False)],
                samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
                type="index",
                label='Task Shortcuts',
            )
            task_inst = gr.Markdown('**Hint:** Upload your image and chat')
            with gr.Row():
                text_input.render()
                send = gr.Button("Send", variant='primary', size='sm', scale=1)

    upload_flag = gr.State(value=0)
    replace_flag = gr.State(value=0)
    image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])

    with gr.Row():
        with gr.Column():
            gr.Examples(examples=[
                ["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
                 img_list],
                ["examples_v2/sofa.jpg", "[detection] sofas", upload_flag, replace_flag, img_list],
                ["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
                 img_list],
                ["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])
        with gr.Column():
            gr.Examples(examples=[
                ["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
                 upload_flag, replace_flag, img_list],
                ["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
                ["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
                ["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
                 replace_flag, img_list],
            ], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
                outputs=[upload_flag, replace_flag])

    dataset.click(
        gradio_taskselect,
        inputs=[dataset],
        outputs=[text_input, task_inst],
        show_progress="hidden",
        postprocess=False,
        queue=False,
    )

    text_input.submit(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    send.click(
        gradio_ask,
        [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
        [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
    ).success(
        gradio_stream_answer,
        [chatbot, chat_state, img_list, temperature],
        [chatbot, chat_state]
    ).success(
        gradio_visualize,
        [chatbot, image],
        [chatbot],
        queue=False,
    )

    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)

demo.launch(share=True, enable_queue=True)
@@ -7,57 +7,27 @@ dependencies:
   - python=3.9
   - cudatoolkit
   - pip
-  - pytorch=1.12.1
-  - pytorch-mutex=1.0=cuda
-  - torchaudio=0.12.1
-  - torchvision=0.13.1
   - pip:
-    - accelerate==0.16.0
-    - aiohttp==3.8.4
-    - aiosignal==1.3.1
-    - async-timeout==4.0.2
-    - attrs==22.2.0
-    - bitsandbytes==0.37.0
-    - cchardet==2.1.7
-    - chardet==5.1.0
-    - contourpy==1.0.7
-    - cycler==0.11.0
-    - filelock==3.9.0
-    - fonttools==4.38.0
-    - frozenlist==1.3.3
-    - huggingface-hub==0.13.4
-    - importlib-resources==5.12.0
-    - kiwisolver==1.4.4
+    - torch==2.0.0
+    - torchaudio
+    - torchvision
+    - huggingface-hub==0.18.0
     - matplotlib==3.7.0
-    - multidict==6.0.4
-    - openai==0.27.0
-    - packaging==23.0
     - psutil==5.9.4
-    - pycocotools==2.0.6
-    - pyparsing==3.0.9
-    - python-dateutil==2.8.2
+    - iopath
     - pyyaml==6.0
     - regex==2022.10.31
     - tokenizers==0.13.2
     - tqdm==4.64.1
-    - transformers==4.28.0
+    - transformers==4.30.0
     - timm==0.6.13
-    - spacy==3.5.1
     - webdataset==0.2.48
-    - scikit-learn==1.2.2
-    - scipy==1.10.1
-    - yarl==1.8.2
-    - zipp==3.14.0
     - omegaconf==2.3.0
     - opencv-python==4.7.0.72
-    - iopath==0.1.10
     - decord==0.6.0
-    - tenacity==8.2.2
-    - peft
-    - pycocoevalcap
+    - peft==0.2.0
     - sentence-transformers
-    - umap-learn
-    - notebook
-    - gradio==3.24.1
-    - gradio-client==0.0.8
+    - gradio==3.47.1
+    - accelerate==0.20.3
+    - bitsandbytes==0.37.0
     - wandb
@@ -1,11 +1,11 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0
   max_txt_len: 160
   end_sym: "###"
   low_resource: True
   prompt_template: '###Human: {} ###Assistant: '
-  ckpt: '/path/to/checkpoint/'
+  ckpt: 'please set this value to the path of pretrained checkpoint'


 datasets:
@@ -1,11 +1,11 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2
   max_txt_len: 160
   end_sym: "</s>"
   low_resource: True
   prompt_template: '[INST] {} [/INST] '
-  ckpt: '/path/to/checkpoint/'
+  ckpt: 'please set this value to the path of pretrained checkpoint'


 datasets:
eval_configs/minigptv2_eval.yaml (new file, 24 lines):

```yaml
model:
  arch: minigpt_v2
  model_type: pretrain
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_template: '[INST] {} [/INST]'
  ckpt: 'please set this value to the path of pretrained checkpoint'
  lora_r: 64
  lora_alpha: 16


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 448
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
```
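The eval config wires the LLaMA-2 chat format through `prompt_template` and `end_sym`. A small illustration of how these two fields shape a prompt string; the user question is a made-up placeholder, and `<ImageHere>` marks where the visual embeddings are spliced in:

```python
# Values taken from eval_configs/minigptv2_eval.yaml above.
prompt_template = '[INST] {} [/INST]'
end_sym = "</s>"  # generation is stopped at this token

user_message = "<Img><ImageHere></Img> describe the image"
print(prompt_template.format(user_message))
# -> [INST] <Img><ImageHere></Img> describe the image [/INST]
# During training, the target answer is terminated with end_sym ("</s>").
```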
New binary files:

- examples_v2/2000x1372_wmkn_0012149409555.jpg (91 KiB)
- examples_v2/KFC-20-for-20-Nuggets.jpg (83 KiB)
- examples_v2/cockdial.png (1.5 MiB)
- examples_v2/float.png (1.2 MiB)
- examples_v2/glip_test.jpg (92 KiB)
- examples_v2/office.jpg (25 KiB)
- examples_v2/sofa.jpg (116 KiB)
- examples_v2/thief.png (865 KiB)
- figs/demo.png (1.1 MiB)
minigpt4/configs/models/minigpt4_llama2.yaml:

```diff
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4

   # vit encoder
   image_size: 224
@@ -12,7 +12,7 @@ model:
   # generation configs
   prompt: ""

-  llama_model: "/path/to/llama2/weight"
+  llama_model: "please set this value to the path of llama2-chat-7b"

 preprocess:
   vis_processor:
```
minigpt4/configs/models/minigpt4_vicuna0.yaml:

```diff
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4

   # vit encoder
   image_size: 224
@@ -15,7 +15,7 @@ model:
   # generation configs
   prompt: ""

-  llama_model: "/path/to/vicuna/weight"
+  llama_model: "please set this value to the path of vicuna model"

 preprocess:
   vis_processor:
```
minigpt4/configs/models/minigpt_v2.yaml (new file, 31 lines):

```yaml
model:
  arch: minigpt_v2

  # vit encoder
  image_size: 448
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True

  # generation configs
  prompt: ""

  llama_model: "please set this value to the path of llama2-chat-7b"
  lora_r: 64
  lora_alpha: 16


preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 448
    eval:
      name: "blip2_image_eval"
      image_size: 448
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
```
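For orientation, a back-of-envelope token count for the new 448 px setting. This assumes `eva_clip_g` is a ViT-g/14 backbone (14 px patches), and uses the 4-to-1 patch merging performed in `encode_img` further below:

```python
# Rough visual-token budget for 448x448 inputs.
image_size = 448
patch_size = 14  # assumption: EVA ViT-g uses 14x14 pixel patches

patches_per_side = image_size // patch_size    # 32
num_patch_tokens = patches_per_side ** 2       # 1024 (plus one CLS token)
num_llm_visual_tokens = num_patch_tokens // 4  # 256 tokens handed to the LLM
print(num_patch_tokens, num_llm_visual_tokens)
```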
minigpt4/conversation/conversation.py:

```diff
@@ -1,10 +1,11 @@
 import argparse
 import time
+from threading import Thread
 from PIL import Image

 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

 import dataclasses
 from enum import auto, Enum
@@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation(


 class Chat:
-    def __init__(self, model, vis_processor, device='cuda:0'):
+    def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
         self.device = device
         self.model = model
         self.vis_processor = vis_processor
-        stop_words_ids = [torch.tensor([835]).to(self.device),
-                          torch.tensor([2277, 29937]).to(self.device)]  # '###' can be encoded in two different ways.
-        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+        if stopping_criteria is not None:
+            self.stopping_criteria = stopping_criteria
+        else:
+            stop_words_ids = [torch.tensor([2]).to(self.device)]
+            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

     def ask(self, text, conv):
         if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
@@ -144,8 +148,8 @@ class Chat:
         else:
             conv.append_message(conv.roles[0], text)

-    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
-               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+    def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+                       repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
@@ -154,10 +158,9 @@ class Chat:
             print('Warning: The number of tokens in current conversation exceeds the max length. '
                   'The model will not see the contexts outside the range.')
         begin_idx = max(0, current_max_len - max_length)
-
         embs = embs[:, begin_idx:]

-        outputs = self.model.llama_model.generate(
+        generation_kwargs = dict(
             inputs_embeds=embs,
             max_new_tokens=max_new_tokens,
             stopping_criteria=self.stopping_criteria,
@@ -169,18 +172,31 @@ class Chat:
             length_penalty=length_penalty,
             temperature=temperature,
         )
-        output_token = outputs[0]
-        if output_token[0] == 0:  # the model might output a unknow token <unk> at the beginning. remove it
-            output_token = output_token[1:]
-        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning. remove it
-            output_token = output_token[1:]
-        output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
+        return generation_kwargs
+
+    def answer(self, conv, img_list, **kargs):
+        generation_dict = self.answer_prepare(conv, img_list, **kargs)
+
+        output_token = self.model.llama_model.generate(**generation_dict)[0]
+        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
+
         output_text = output_text.split('###')[0]  # remove the stop sign '###'
         output_text = output_text.split('Assistant:')[-1].strip()

         conv.messages[-1][1] = output_text
         return output_text, output_token.cpu().numpy()

-    def upload_img(self, image, conv, img_list):
+    def stream_answer(self, conv, img_list, **kargs):
+        generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
+        streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
+        generation_kwargs['streamer'] = streamer
+        thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        return streamer
+
+    def encode_img(self, img_list):
+        image = img_list[0]
+        img_list.pop(0)
         if isinstance(image, str):  # is a image path
             raw_image = Image.open(image).convert('RGB')
             image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@@ -194,9 +210,12 @@ class Chat:

         image_emb, _ = self.model.encode_img(image)
         img_list.append(image_emb)
+
+    def upload_img(self, image, conv, img_list):
         conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
+        img_list.append(image)
         msg = "Received."
-        # self.conv.append_message(self.conv.roles[1], msg)
         return msg

     def get_context_emb(self, conv, img_list):
@@ -209,7 +228,9 @@ class Chat:
             # only add bos to the first seg
             for i, seg in enumerate(prompt_segs)
         ]
-        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
+        print('debug device: ', self.device)
+        print('debug model device: ', self.model.device)
+        seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
         mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
         mixed_embs = torch.cat(mixed_embs, dim=1)
         return mixed_embs
```
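A minimal sketch of driving the new streaming path. It assumes `model` and `vis_processor` have already been built from an eval config as in the demo script, and that `Conversation` provides a `copy()` method as elsewhere in the repo; the image path is one of the example files added above:

```python
from minigpt4.conversation.conversation import Chat, CONV_VISION_LLama2

chat = Chat(model, vis_processor, device='cuda:0')

conv = CONV_VISION_LLama2.copy()
img_list = []
chat.upload_img('examples_v2/office.jpg', conv, img_list)
chat.encode_img(img_list)  # turns the raw image into a visual embedding
chat.ask('What is in this image?', conv)

# stream_answer launches generate() on a background thread and returns a
# TextIteratorStreamer that yields decoded text chunks as they are produced.
streamer = chat.stream_answer(conv, img_list, max_new_tokens=300)
for chunk in streamer:
    print(chunk, end='', flush=True)
```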
minigpt4/models/__init__.py:

```diff
@@ -11,16 +11,18 @@ from omegaconf import OmegaConf

 from minigpt4.common.registry import registry
 from minigpt4.models.base_model import BaseModel
-from minigpt4.models.blip2 import Blip2Base
-from minigpt4.models.mini_gpt4 import MiniGPT4
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.minigpt4 import MiniGPT4
+from minigpt4.models.minigpt_v2 import MiniGPTv2
 from minigpt4.processors.base_processor import BaseProcessor


 __all__ = [
     "load_model",
     "BaseModel",
-    "Blip2Base",
+    "MiniGPTBase",
     "MiniGPT4",
+    "MiniGPTv2"
 ]
```
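After the rename, model classes resolve through the registry under their new keys. A short sketch, assuming a config object `model_config` whose `arch` field is `minigpt4` (as in the updated YAML files above) and the `registry.get_model_class` lookup used by the repo's demo and training entry points:

```python
from minigpt4.common.registry import registry

model_cls = registry.get_model_class("minigpt4")  # previously registered as "mini_gpt4"
model = model_cls.from_config(model_config).to("cuda:0")
```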
minigpt4/models/base_model.py:

```diff
@@ -5,15 +5,26 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """

-import logging
 import os
+import logging
+import contextlib
+
+from omegaconf import OmegaConf
 import numpy as np
 import torch
 import torch.nn as nn
+from transformers import BertTokenizer, LlamaTokenizer
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_int8_training,
+)

 from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
 from minigpt4.common.utils import get_abs_path, is_url
-from omegaconf import OmegaConf
+from minigpt4.models.eva_vit import create_eva_vit_g



 class BaseModel(nn.Module):
@@ -117,131 +128,121 @@ class BaseModel(nn.Module):
         else:
             return tot

-
-class BaseEncoder(nn.Module):
-    """
-    Base class for primitive encoders, such as ViT, TimeSformer, etc.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward_features(self, samples, **kwargs):
-        raise NotImplementedError
-
-    @property
-    def device(self):
-        return list(self.parameters())[0].device
-
-
-class SharedQueueMixin:
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
-        # gather keys before updating queue
-        image_feats = concat_all_gather(image_feat)
-        text_feats = concat_all_gather(text_feat)
-
-        batch_size = image_feats.shape[0]
-
-        ptr = int(self.queue_ptr)
-        assert self.queue_size % batch_size == 0  # for simplicity
-
-        # replace the keys at ptr (dequeue and enqueue)
-        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
-        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
-
-        if idxs is not None:
-            idxs = concat_all_gather(idxs)
-            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
-
-        ptr = (ptr + batch_size) % self.queue_size  # move pointer
-        self.queue_ptr[0] = ptr
-
-
-class MomentumDistilationMixin:
-    @torch.no_grad()
-    def copy_params(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data.copy_(param.data)  # initialize
-                param_m.requires_grad = False  # not update by gradient
-
-    @torch.no_grad()
-    def _momentum_update(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data = param_m.data * self.momentum + param.data * (
-                    1.0 - self.momentum
-                )
-
-
-class GatherLayer(torch.autograd.Function):
-    """
-    Gather tensors from all workers with support for backward propagation:
-    This implementation does not cut the gradients as torch.distributed.all_gather does.
-    """
-
-    @staticmethod
-    def forward(ctx, x):
-        output = [
-            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
-        ]
-        torch.distributed.all_gather(output, x)
-        return tuple(output)
-
-    @staticmethod
-    def backward(ctx, *grads):
-        all_gradients = torch.stack(grads)
-        torch.distributed.all_reduce(all_gradients)
-        return all_gradients[torch.distributed.get_rank()]
-
-
-def all_gather_with_grad(tensors):
-    """
-    Performs all_gather operation on the provided tensors.
-    Graph remains connected for backward grad computation.
-    """
-    # Queue the gathered tensors
-    world_size = torch.distributed.get_world_size()
-    # There is no need for reduction in the single-proc case
-    if world_size == 1:
-        return tensors
-
-    tensor_all = GatherLayer.apply(tensors)
-
-    return torch.cat(tensor_all, dim=0)
-
-
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    # if use distributed training
-    if not is_dist_avail_and_initialized():
-        return tensor
-
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-def tile(x, dim, n_tile):
-    init_dim = x.size(dim)
-    repeat_idx = [1] * x.dim()
-    repeat_idx[dim] = n_tile
-    x = x.repeat(*(repeat_idx))
-    order_index = torch.LongTensor(
-        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
-    )
-    return torch.index_select(x, dim, order_index.to(x.device))
+    def maybe_autocast(self, dtype=torch.float16):
+        # if on cpu, don't use autocast
+        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+        enable_autocast = self.device != torch.device("cpu")
+
+        if enable_autocast:
+            return torch.cuda.amp.autocast(dtype=dtype)
+        else:
+            return contextlib.nullcontext()
+
+    @classmethod
+    def init_vision_encoder(
+            cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
+    ):
+        logging.info('Loading VIT')
+
+        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
+        if not freeze:
+            precision = "fp32"  # fp16 is not for training
+
+        visual_encoder = create_eva_vit_g(
+            img_size, drop_path_rate, use_grad_checkpoint, precision
+        )
+
+        ln_vision = LayerNorm(visual_encoder.num_features)
+
+        if freeze:
+            for name, param in visual_encoder.named_parameters():
+                param.requires_grad = False
+            visual_encoder = visual_encoder.eval()
+            visual_encoder.train = disabled_train
+            for name, param in ln_vision.named_parameters():
+                param.requires_grad = False
+            ln_vision = ln_vision.eval()
+            ln_vision.train = disabled_train
+            logging.info("freeze vision encoder")
+
+        logging.info('Loading VIT Done')
+        return visual_encoder, ln_vision
+
+    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
+                 lora_target_modules=["q_proj", "v_proj"], **lora_kargs):
+        logging.info('Loading LLAMA')
+        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
+        llama_tokenizer.pad_token = "$$"
+
+        if low_resource:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map={'': low_res_device}
+            )
+        else:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+            )
+
+        if lora_r > 0:
+            llama_model = prepare_model_for_int8_training(llama_model)
+            loraconfig = LoraConfig(
+                r=lora_r,
+                bias="none",
+                task_type="CAUSAL_LM",
+                target_modules=lora_target_modules,
+                **lora_kargs
+            )
+            llama_model = get_peft_model(llama_model, loraconfig)
+
+            llama_model.print_trainable_parameters()
+
+        else:
+            for name, param in llama_model.named_parameters():
+                param.requires_grad = False
+        logging.info('Loading LLAMA Done')
+        return llama_model, llama_tokenizer
+
+    def load_from_pretrained(self, url_or_filename):
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location="cpu")
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location="cpu")
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        # logging.info("Missing keys {}".format(msg.missing_keys))
+        logging.info("load checkpoint from %s" % url_or_filename)
+
+        return msg
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
```
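`init_llm` only wraps the language model in LoRA adapters when `lora_r > 0`. A self-contained sketch of that wrapping, using the small public `facebook/opt-125m` checkpoint purely as a stand-in for the LLaMA weights (the `r` and `lora_alpha` values mirror the configs above):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Stand-in base model; MiniGPT-v2 applies the same recipe to LLaMA-2.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)

lora_config = LoraConfig(
    r=64,                                 # lora_r in minigpt_v2.yaml
    lora_alpha=16,                        # lora_alpha in minigpt_v2.yaml
    target_modules=["q_proj", "v_proj"],  # attention projections, as in init_llm
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # only the low-rank adapters remain trainable
```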
minigpt4/models/blip2.py (deleted, 221 lines). Removed content:

```python
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F

import minigpt4.common.dist_utils as dist_utils
from minigpt4.common.dist_utils import download_cached_file
from minigpt4.common.utils import is_url
from minigpt4.common.logger import MetricLogger
from minigpt4.models.base_model import BaseModel
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
from minigpt4.models.eva_vit import create_eva_vit_g
from transformers import BertTokenizer


class Blip2Base(BaseModel):
    @classmethod
    def init_tokenizer(cls):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
    ):
        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)
        return visual_encoder, ln_vision

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def compute_sim_matrix(model, data_loader, **kwargs):
    k_test = kwargs.pop("k_test")

    metric_logger = MetricLogger(delimiter="  ")
    header = "Evaluation:"

    logging.info("Computing features for evaluation...")
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = model.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(model.device)
        text_feat = model.forward_text(text_input)
        text_embed = F.normalize(model.text_proj(text_feat))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    vit_feats = []
    image_embeds = []
    for samples in data_loader:
        image = samples["image"]

        image = image.to(model.device)
        image_feat, vit_feat = model.forward_image(image)
        image_embed = model.vision_proj(image_feat)
        image_embed = F.normalize(image_embed, dim=-1)

        vit_feats.append(vit_feat.cpu())
        image_embeds.append(image_embed)

    vit_feats = torch.cat(vit_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = []
    for image_embed in image_embeds:
        sim_q2t = image_embed @ text_embeds.t()
        sim_i2t, _ = sim_q2t.max(0)
        sims_matrix.append(sim_i2t)
    sims_matrix = torch.stack(sims_matrix, dim=0)

    score_matrix_i2t = torch.full(
        (len(data_loader.dataset.image), len(texts)), -100.0
    ).to(model.device)

    num_tasks = dist_utils.get_world_size()
    rank = dist_utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[topk_idx],
            text_atts=text_atts[topk_idx],
        ).float()
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full(
        (len(texts), len(data_loader.dataset.image)), -100.0
    ).to(model.device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[start + i].repeat(k_test, 1),
            text_atts=text_atts[start + i].repeat(k_test, 1),
        ).float()
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    if dist_utils.is_dist_avail_and_initialized():
        dist.barrier()
        torch.distributed.all_reduce(
            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
        )
        torch.distributed.all_reduce(
            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
        )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Evaluation time {}".format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
```
minigpt4/models/blip2_outputs.py (deleted, 110 lines). Removed content:

```python
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import (
    ModelOutput,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)


@dataclass
class BlipSimilarity(ModelOutput):
    sim_i2t: torch.FloatTensor = None
    sim_t2i: torch.FloatTensor = None

    sim_i2t_m: Optional[torch.FloatTensor] = None
    sim_t2i_m: Optional[torch.FloatTensor] = None

    sim_i2t_targets: Optional[torch.FloatTensor] = None
    sim_t2i_targets: Optional[torch.FloatTensor] = None


@dataclass
class BlipIntermediateOutput(ModelOutput):
    """
    Data class for intermediate outputs of BLIP models.

    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).

    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).

    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.

    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
    decoder_labels (torch.LongTensor): labels for the captioning loss.

    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)

    """

    # uni-modal features
    image_embeds: torch.FloatTensor = None
    text_embeds: Optional[torch.FloatTensor] = None

    image_embeds_m: Optional[torch.FloatTensor] = None
    text_embeds_m: Optional[torch.FloatTensor] = None

    # intermediate outputs of multimodal encoder
    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None

    itm_logits: Optional[torch.FloatTensor] = None
    itm_labels: Optional[torch.LongTensor] = None

    # intermediate outputs of multimodal decoder
    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
    decoder_labels: Optional[torch.LongTensor] = None


@dataclass
class BlipOutput(ModelOutput):
    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
    sims: Optional[BlipSimilarity] = None

    intermediate_output: BlipIntermediateOutput = None

    loss: Optional[torch.FloatTensor] = None

    loss_itc: Optional[torch.FloatTensor] = None

    loss_itm: Optional[torch.FloatTensor] = None

    loss_lm: Optional[torch.FloatTensor] = None


@dataclass
class BlipOutputFeatures(ModelOutput):
    """
    Data class of features from BlipFeatureExtractor.

    Args:
        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional

        The first embedding or feature is for the [CLS] token.

        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    image_embeds_proj: Optional[torch.FloatTensor] = None

    text_embeds: Optional[torch.FloatTensor] = None
    text_embeds_proj: Optional[torch.FloatTensor] = None

    multimodal_embeds: Optional[torch.FloatTensor] = None
```
minigpt4/models/mini_gpt4.py (deleted, 384 lines). Removed content:

```python
import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)


@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
    """
    BLIP2 GPT-LLAMA model.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
        "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
        lora_r=0,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        self.has_qformer = has_qformer
        if self.has_qformer:
            print('Loading Q-Former')
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token, self.visual_encoder.num_features
            )
            self.Qformer.cls = None
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.load_from_pretrained(url_or_filename=q_former_model)

            if freeze_qformer:
                for name, param in self.Qformer.named_parameters():
                    param.requires_grad = False
                self.Qformer = self.Qformer.eval()
                self.Qformer.train = disabled_train
                self.query_tokens.requires_grad = False
                logging.info("freeze Qformer")

            img_f_dim = self.Qformer.config.hidden_size
            print('Loading Q-Former Done')
        else:
            img_f_dim = self.visual_encoder.num_features * 4
            print('Do not use Q-Former here.')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
        self.llama_tokenizer.pad_token = "$$"

        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': device_8bit}
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_model,
                torch_dtype=torch.float16,
            )

        if lora_r > 0:
            self.llama_model = prepare_model_for_int8_training(self.llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM"
            )
            self.llama_model = get_peft_model(self.llama_model, loraconfig)

            # if ckpt_path:
            #     print('load the llm under lora')
            #     ckpt = torch.load(ckpt_path)
            #     set_peft_model_state_dict(self.llama_model,ckpt)
            self.llama_model.print_trainable_parameters()

        else:
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
        print('Loading LLAMA Done')

        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []

    def vit_to_cpu(self):
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def encode_img(self, image):
        device = image.device
        if self.low_resource:
            self.vit_to_cpu()
            image = image.to("cpu")

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            if self.has_qformer:
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_output = self.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

                inputs_llama = self.llama_proj(query_output.last_hidden_state)
            else:
                image_embeds = image_embeds[:, 1:, :]
                bs, pn, hs = image_embeds.shape
                image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))

                inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    def get_context_emb(self, prompt, img_list):
        device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def prompt_wrap(self, img_embeds, atts_img, prompts):
        if prompts:
            emb_lists = []
            if isinstance(prompts, str):
                prompts = [prompts] * len(img_embeds)

            for each_img_embed, each_prompt in zip(img_embeds, prompts):
                p_before, p_after = each_prompt.split('<ImageHere>')

                p_before_tokens = self.llama_tokenizer(
                    p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_after_tokens = self.llama_tokenizer(
                    p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
                p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
                wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
                emb_lists.append(wrapped_emb)
            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
            wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
            wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
            for i, emb in enumerate(emb_lists):
                wrapped_embs[i, :emb_lens[i]] = emb
                wrapped_atts[i, :emb_lens[i]] = 1
            return wrapped_embs, wrapped_atts
        else:
            return img_embeds, atts_img

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens

    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)

        if self.prompt_list:
            instruction = random.choice(self.prompt_list)
        else:
            instruction = samples["instruction_input"] if "instruction_input" in samples else None

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)

        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["answer"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(image.device)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, attention_mask], dim=1)

        part_targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        targets = (
            torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                       dtype=torch.long).to(image.device).fill_(-100)
        )

        for i, target in enumerate(part_targets):
            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
        loss = outputs.loss

        return {"loss": loss}

    def embed_tokens(self, token_ids):
        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        else:
            embeds = self.llama_model.base_model.embed_tokens(token_ids)
        return embeds

    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
        q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        has_qformer = cfg.get("has_qformer", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        low_resource = cfg.get("low_resource", False)
        device_8bit = cfg.get("device_8bit", 0)

        prompt_path = cfg.get("prompt_path", "")
        prompt_template = cfg.get("prompt_template", "")
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 0)
        lora_alpha = cfg.get("lora_alpha", 32)

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            has_qformer=has_qformer,
            freeze_qformer=freeze_qformer,
            num_query_token=num_query_token,
            llama_model=llama_model,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
```
195
minigpt4/models/minigpt4.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.cuda.amp import autocast as autocast
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from minigpt4.common.registry import registry
|
||||||
|
from minigpt4.models.base_model import disabled_train
|
||||||
|
from minigpt4.models.minigpt_base import MiniGPTBase
|
||||||
|
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register_model("minigpt4")
|
||||||
|
class MiniGPT4(MiniGPTBase):
|
||||||
|
"""
|
||||||
|
MiniGPT-4 model
|
||||||
|
"""
|
||||||
|
|
||||||
|
PRETRAINED_MODEL_CONFIG_DICT = {
|
||||||
|
"pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
|
||||||
|
"pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vit_model="eva_clip_g",
|
||||||
|
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
|
||||||
|
img_size=224,
|
||||||
|
drop_path_rate=0,
|
||||||
|
use_grad_checkpoint=False,
|
||||||
|
vit_precision="fp16",
|
||||||
|
freeze_vit=True,
|
||||||
|
has_qformer=True,
|
||||||
|
freeze_qformer=True,
|
||||||
|
num_query_token=32,
|
||||||
|
llama_model="",
|
||||||
|
prompt_path="",
|
||||||
|
prompt_template="",
|
||||||
|
max_txt_len=32,
|
||||||
|
end_sym='\n',
|
||||||
|
low_resource=False, # use 8 bit and put vit in cpu
|
||||||
|
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
vit_model=vit_model,
|
||||||
|
img_size=img_size,
|
||||||
|
drop_path_rate=drop_path_rate,
|
||||||
|
use_grad_checkpoint=use_grad_checkpoint,
|
||||||
|
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            max_txt_len=max_txt_len,
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
        )

        self.has_qformer = has_qformer
        if self.has_qformer:
            print('Loading Q-Former')
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token, self.visual_encoder.num_features, freeze_qformer
            )
            self.load_from_pretrained(url_or_filename=q_former_model)  # load q-former weights here

            img_f_dim = self.Qformer.config.hidden_size
            print('Loading Q-Former Done')
        else:
            img_f_dim = self.visual_encoder.num_features * 4
            print('Do not use Q-Former here.')

        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, freeze):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 2
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        Qformer.cls = None
        Qformer.bert.embeddings.word_embeddings = None
        Qformer.bert.embeddings.position_embeddings = None
        for layer in Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        if freeze:
            for name, param in Qformer.named_parameters():
                param.requires_grad = False
            Qformer = Qformer.eval()
            Qformer.train = disabled_train
            query_tokens.requires_grad = False
            logging.info("freeze Qformer")

        return Qformer, query_tokens

    def encode_img(self, image):
        device = image.device

        if len(image.shape) > 4:
            image = image.reshape(-1, *image.shape[-3:])

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            if self.has_qformer:
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_output = self.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

                inputs_llama = self.llama_proj(query_output.last_hidden_state)
            else:
                image_embeds = image_embeds[:, 1:, :]
                bs, pn, hs = image_embeds.shape
                image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))

                inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
        q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        has_qformer = cfg.get("has_qformer", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        low_resource = cfg.get("low_resource", False)
        device_8bit = cfg.get("device_8bit", 0)

        prompt_path = cfg.get("prompt_path", "")
        prompt_template = cfg.get("prompt_template", "")
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            has_qformer=has_qformer,
            freeze_qformer=freeze_qformer,
            num_query_token=num_query_token,
            llama_model=llama_model,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
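For orientation, here is a minimal sketch of driving the `from_config` path above from Python. The `cfg` keys simply mirror the `cfg.get(...)` calls in `from_config`; the module path, the OmegaConf usage, and every path/value below are illustrative assumptions, not copies of the repo's actual config files.

```python
# Illustrative sketch only (assumed module path, keys, and placeholder paths).
from omegaconf import OmegaConf
from minigpt4.models.minigpt4 import MiniGPT4  # assumed location of the class above

cfg = OmegaConf.create({
    "image_size": 224,
    "num_query_token": 32,
    "llama_model": "/path/to/vicuna-7b",          # placeholder
    "prompt_path": "prompts/alignment.txt",       # placeholder; "" disables prompt sampling
    "prompt_template": "###Human: {} ###Assistant: ",
    "end_sym": "###",
    "ckpt": "",                                   # optional stage-1/stage-2 checkpoint
})

model = MiniGPT4.from_config(cfg)  # builds ViT + (optional) Q-Former + LLaMA projection
```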
401
minigpt4/models/minigpt_base.py
Normal file
@ -0,0 +1,401 @@
import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
# The two imports below are needed by generate() further down in this file;
# StoppingCriteriaSub is defined in minigpt4/conversation/conversation.py.
from transformers import StoppingCriteriaList
from minigpt4.conversation.conversation import StoppingCriteriaSub


class MiniGPTBase(BaseModel):
    """
    Base class for MiniGPT-4 and MiniGPT-v2
    """

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        llama_model="",
        max_txt_len=32,
        max_context_len=3800,
        prompt_template="",
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of the 8-bit model must be set at load time and cannot be changed afterwards
        lora_r=0,  # lora_r == 0 means LoRA is not used
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        super().__init__()

        self.llama_model, self.llama_tokenizer = self.init_llm(
            llama_model_path=llama_model,
            low_resource=low_resource,
            low_res_device=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
        )

        self.max_txt_len = max_txt_len
        self.max_context_len = max_context_len
        self.end_sym = end_sym

        self.prompt_template = prompt_template
        self.prompt_list = []
    def vit_to_cpu(self):
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def get_context_emb(self, prompt, img_list):
        device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids  # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
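    # Editor's note (illustrative, not part of the original file): for a prompt such as
    #   "Give the image: <ImageHere> Describe it in detail."
    # get_context_emb() splits the string on '<ImageHere>', tokenizes each text segment
    # (BOS only on the first), embeds them, and interleaves them with the image embeddings:
    #   mixed = [emb(seg_0), img_list[0], emb(seg_1)]  ->  torch.cat(mixed, dim=1)
    # so the returned tensor has shape (1, len_seg_0 + n_img_tokens + len_seg_1, hidden_size).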
    def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
        if prompts is None or len(prompts) == 0:
            # no prompts are provided, just return the original image embedding
            return img_embeds, atts_img
        elif img_embeds is None:
            # prompts are provided but there is no image embedding; return the prompt embedding with right padding
            self.llama_tokenizer.padding_side = "right"
            prompt_tokens = self.llama_tokenizer(
                prompts,
                return_tensors="pt",
                padding="longest",
                add_special_tokens=False
            ).to(self.device)
            prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
            atts_prompt = prompt_tokens.attention_mask
            return prompt_embeds, atts_prompt
        else:
            # return the multi-modal embedding with right padding
            emb_lists = []
            if isinstance(prompts, str):
                prompts = [prompts] * len(img_embeds)

            for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
                pn = each_img_embed.shape[-2]
                if lengths is not None:
                    each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
                    each_img_embed = each_img_embed[:lengths[idx] * pn]
                p_segs = each_prompt.split('<ImageHere>')
                interleave_emb = []
                for idx, seg in enumerate(p_segs[:-1]):
                    p_tokens = self.llama_tokenizer(
                        seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                    p_embed = self.embed_tokens(p_tokens.input_ids)
                    interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1))
                wrapped_emb = torch.cat(interleave_emb, dim=1)
                p_tokens = self.llama_tokenizer(
                    p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_embed = self.embed_tokens(p_tokens.input_ids)
                wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
                emb_lists.append(wrapped_emb)

            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))

            max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
            wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
            wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)

            for i, emb in enumerate(emb_lists):
                length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
                wrapped_embs[i, :length] = emb[:, :length]
                wrapped_atts[i, :length] = 1
            return wrapped_embs, wrapped_atts
    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        """
        Concatenate the batched input embedding and batched output embedding together.
        Both the input and the output embedding should be right padded.
        """
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens
    def tokenize_conversation(self, conv_q, conv_a):
        """concatenate conversation and make sure the model is only trained to regress the answer"""

        to_regress_token_ids_list = []
        targets_list = []

        batch_size = len(conv_q)
        for batch_idx in range(batch_size):
            questions, answers = conv_q[batch_idx], conv_a[batch_idx]
            questions = [self.llama_tokenizer(q,
                                              return_tensors="pt",
                                              add_special_tokens=False).to(self.device) for q in questions[1:]]  # the first question is handled in the prompt wrap function, skip it
            answers = [self.llama_tokenizer(q,
                                            return_tensors="pt",
                                            add_special_tokens=False).to(self.device) for q in answers]
            cur_id = []
            cur_target = []
            for i in range(len(questions)):
                cur_id.append(answers[i].input_ids)
                cur_target.append(answers[i].input_ids)
                cur_id.append(questions[i].input_ids)
                cur_target.append(torch.ones_like(questions[i].input_ids) * -100)

            cur_id.append(answers[-1].input_ids)
            cur_target.append(answers[-1].input_ids)

            cur_id = torch.cat(cur_id, dim=1)
            cur_target = torch.cat(cur_target, dim=1)
            to_regress_token_ids_list.append(cur_id)
            targets_list.append(cur_target)

        max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
        to_regress_token_ids = torch.ones([batch_size, max_len],
                                          dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id
        targets = torch.ones([batch_size, max_len],
                             dtype=cur_id.dtype, device=self.device) * -100
        for batch_idx in range(batch_size):
            cur_len = to_regress_token_ids_list[batch_idx].shape[1]
            to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len]
            targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]

        to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int)

        return to_regress_token_ids, to_regress_token_attn, targets
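    # Editor's note (illustrative, not part of the original file): tokenize_conversation()
    # flattens one dialogue into [a_0, q_1, a_1, ..., q_n, a_n] (the first question already
    # sits in the wrapped prompt) and builds targets that copy the answer token ids while
    # masking every question position with -100, which CrossEntropyLoss ignores:
    #   token ids: [ a0 a0 a0 | q1 q1 | a1 a1 | pad  ]
    #   targets:   [ a0 a0 a0 | -100  | a1 a1 | -100 ]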
    def preparing_embedding(self, samples):
        ### prepare input tokens
        if 'image' in samples:
            img_embeds, img_atts = self.encode_img(samples["image"])
        else:
            img_embeds = img_atts = None

        if 'conv_q' in samples:
            # handling conversation datasets
            conv_q, conv_a = samples['conv_q'], samples['conv_a']

            connect_sym = samples['connect_sym'][0]
            conv_q = [q.split(connect_sym) for q in conv_q]
            conv_a = [a.split(connect_sym) for a in conv_a]

            conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]

            cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
            regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)

        else:
            if "instruction_input" in samples:
                instruction = samples["instruction_input"]
            elif self.prompt_list:
                instruction = random.choice(self.prompt_list)
            else:
                instruction = None

            if self.chat_template:
                instruction = [self.prompt_template.format(instruct) for instruct in instruction]

            if 'length' in samples:
                # the input is a sequence of images (e.g. video frames)
                bsz, pn, hs = img_embeds.shape
                img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
            else:
                cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)

        ### prepare target tokens
        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["answer"]]

        regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(self.device)

        regress_token_ids = regress_tokens.input_ids
        regress_atts = regress_tokens.attention_mask
        part_targets = regress_token_ids.masked_fill(
            regress_token_ids == self.llama_tokenizer.pad_token_id, -100
        )

        regress_embeds = self.embed_tokens(regress_token_ids)

        return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
    def forward(self, samples, reduction='mean'):
        # prepare the embedding to condition on and the embedding to regress
        cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
            self.preparing_embedding(samples)

        # concat the embedding to condition on and the embedding to regress
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)

        # get bos token embedding
        bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        bos_atts = cond_atts[:, :1]

        # add bos token at the beginning
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([bos_atts, attention_mask], dim=1)

        # assemble the final targets
        targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                             dtype=torch.long).to(self.device).fill_(-100)

        for i, target in enumerate(part_targets):
            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                reduction=reduction
            )
        loss = outputs.loss

        return {"loss": loss}

    def embed_tokens(self, token_ids):
        if hasattr(self.llama_model.base_model, 'model'):  # LoRA-wrapped model
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        else:
            embeds = self.llama_model.base_model.embed_tokens(token_ids)
        return embeds
    @torch.no_grad()
    def generate(
        self,
        images,
        texts,
        num_beams=1,
        max_new_tokens=20,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1,
        length_penalty=1,
        temperature=1,
        do_sample=False,
        stop_words_ids=[2],
    ):
        '''
        Generation function used at test/evaluation time.
        '''

        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
            stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])

        img_embeds, atts_img = self.encode_img(images.to(self.device))
        image_lists = [[image_emb[None]] for image_emb in img_embeds]

        batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]

        batch_size = len(batch_embs)
        max_len = max([emb.shape[1] for emb in batch_embs])
        emb_dim = batch_embs[0].shape[2]
        dtype = batch_embs[0].dtype
        device = batch_embs[0].device

        embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
        attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
        for i, emb in enumerate(batch_embs):
            emb_len = emb.shape[1]
            embs[i, -emb_len:] = emb[0]
            attn_mask[i, -emb_len:] = 1

        with self.maybe_autocast():
            outputs = self.llama_model.generate(
                inputs_embeds=embs,
                attention_mask=attn_mask,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                length_penalty=length_penalty,
                temperature=temperature,
                do_sample=do_sample,
                min_length=min_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                # stopping_criteria=stopping_criteria,
            )

        answers = []
        for output_token in outputs:
            if output_token[0] == 0:
                output_token = output_token[1:]
            output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
            output_texts = output_texts.split('</s>')[0]  # remove the stop sign </s>
            output_texts = output_texts.replace("<s>", "")
            output_texts = output_texts.split(r'[/INST]')[-1].strip()
            answers.append(output_texts)

        return answers
    @torch.no_grad()
    def multi_select(self, images, texts, answers, num_cand=None):
        all_losses = []
        for answer in answers:
            choice_samples = {
                'image': images,
                'instruction_input': texts,
                'answer': answer
            }
            loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
            all_losses.append(loss)
            torch.cuda.empty_cache()
        all_losses = torch.cat(all_losses, dim=-1)
        if num_cand is not None:
            for i in range(all_losses.shape[0]):
                all_losses[i, num_cand[i]:] = 9999
        output_class_ranks = torch.argsort(all_losses, dim=-1)
        return output_class_ranks.tolist()
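Before moving on to MiniGPT-v2, a short usage sketch of the batched `generate()` defined above may help. Everything in it is an assumption for illustration: the image transform (`vis_processor`), the image path, and the `[INST]`-style prompt are placeholders; the repository's demo scripts wire this up through their own conversation utilities.

```python
# Illustrative sketch only; `model` is a loaded MiniGPT model and `vis_processor`
# its image preprocessor (both assumed to exist in your setup).
import torch
from PIL import Image

image = vis_processor(Image.open("example.jpg").convert("RGB"))  # (3, H, W) tensor
images = image.unsqueeze(0)                                      # add batch dim -> (1, 3, H, W)
texts = ["[INST] <Img><ImageHere></Img> Describe this image in detail. [/INST]"]

with torch.no_grad():
    answers = model.generate(images, texts, max_new_tokens=100, do_sample=False)
print(answers[0])
```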
139
minigpt4/models/minigpt_v2.py
Normal file
@ -0,0 +1,139 @@
import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.base_model import disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel


@registry.register_model("minigpt_v2")
class MiniGPTv2(MiniGPTBase):
    """
    MiniGPT-v2 model
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/minigpt_v2.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=448,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        llama_model="",
        prompt_template='[INST] {} [/INST]',
        max_txt_len=300,
        end_sym='\n',
        lora_r=64,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
        chat_template=False,
        use_grad_checkpoint_llm=False,
        max_context_len=3800,
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of the 8-bit model must be set at load time and cannot be changed afterwards
    ):
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            max_txt_len=max_txt_len,
            max_context_len=max_context_len,
            end_sym=end_sym,
            prompt_template=prompt_template,
            low_resource=low_resource,
            device_8bit=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

        img_f_dim = self.visual_encoder.num_features * 4
        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.chat_template = chat_template

        if use_grad_checkpoint_llm:
            self.llama_model.gradient_checkpointing_enable()

    def encode_img(self, image):
        device = image.device

        if len(image.shape) > 4:
            image = image.reshape(-1, *image.shape[-3:])

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_embeds = image_embeds[:, 1:, :]
            bs, pn, hs = image_embeds.shape
            image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))

            inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("image_size")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        low_resource = cfg.get("low_resource", False)

        prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
        max_txt_len = cfg.get("max_txt_len", 300)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 64)
        lora_alpha = cfg.get("lora_alpha", 16)
        chat_template = cfg.get("chat_template", False)

        use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
        max_context_len = cfg.get("max_context_len", 3800)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            low_resource=low_resource,
            end_sym=end_sym,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            chat_template=chat_template,
            use_grad_checkpoint_llm=use_grad_checkpoint_llm,
            max_context_len=max_context_len,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load MiniGPT-4-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
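Because `MiniGPTv2` registers itself under the name `minigpt_v2`, it can also be resolved through the project's registry rather than imported directly. A minimal sketch follows, assuming the registry exposes `get_model_class` as in LAVIS-style registries; the config keys and checkpoint paths are placeholders.

```python
# Illustrative sketch only (assumed registry API and placeholder paths).
from omegaconf import OmegaConf
from minigpt4.common.registry import registry

model_cls = registry.get_model_class("minigpt_v2")
cfg = OmegaConf.create({
    "image_size": 448,
    "llama_model": "/path/to/llama-2-7b-chat-hf",   # placeholder
    "ckpt": "/path/to/minigptv2_checkpoint.pth",    # placeholder
})
model = model_cls.from_config(cfg).eval()
```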
@ -1,628 +1,17 @@
|
|||||||
# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
|
||||||
|
|
||||||
""" PyTorch LLaMA model."""
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch.nn import CrossEntropyLoss
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
||||||
|
|
||||||
from transformers.activations import ACT2FN
|
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
|
||||||
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
class LlamaForCausalLM(LlamaForCausalLMOrig):
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
|
||||||
def _make_causal_mask(
|
|
||||||
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Make causal mask used for bi-directional self-attention.
|
|
||||||
"""
|
|
||||||
bsz, tgt_len = input_ids_shape
|
|
||||||
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
|
||||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
|
||||||
mask = mask.to(dtype)
|
|
||||||
|
|
||||||
if past_key_values_length > 0:
|
|
||||||
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
|
||||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
|
||||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
||||||
"""
|
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
||||||
"""
|
|
||||||
bsz, src_len = mask.size()
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
||||||
|
|
||||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
|
||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
|
||||||
|
|
||||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaRMSNorm(nn.Module):
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
"""
|
|
||||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
|
|
||||||
# convert into half-precision if necessary
|
|
||||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
|
|
||||||
return self.weight * hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaRotaryEmbedding(torch.nn.Module):
|
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
|
||||||
super().__init__()
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
|
||||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
|
||||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
|
||||||
return (
|
|
||||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
|
||||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
||||||
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
|
||||||
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
|
||||||
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
|
||||||
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
hidden_act: str,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
|
||||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
|
||||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
|
||||||
self.act_fn = ACT2FN[hidden_act]
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaAttention(nn.Module):
|
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
self.num_heads = config.num_attention_heads
|
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
|
||||||
self.max_position_embeddings = config.max_position_embeddings
|
|
||||||
|
|
||||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
|
||||||
f" and `num_heads`: {self.num_heads})."
|
|
||||||
)
|
|
||||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
|
||||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
|
||||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
|
||||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
|
||||||
|
|
||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
||||||
# [bsz, nh, t, hd]
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
|
||||||
f" {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
if not output_attentions:
|
|
||||||
attn_weights = None
|
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
|
||||||
def __init__(self, config: LlamaConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
self.self_attn = LlamaAttention(config=config)
|
|
||||||
self.mlp = LlamaMLP(
|
|
||||||
hidden_size=self.hidden_size,
|
|
||||||
intermediate_size=config.intermediate_size,
|
|
||||||
hidden_act=config.hidden_act,
|
|
||||||
)
|
|
||||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
|
||||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
||||||
returned tensors for more detail.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
|
||||||
(see `past_key_values`).
|
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
|
||||||
"""
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
|
|
||||||
# Self Attention
|
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
# Fully Connected
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
outputs += (self_attn_weights,)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
outputs += (present_key_value,)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
LLAMA_START_DOCSTRING = r"""
|
|
||||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
||||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
|
||||||
etc.)
|
|
||||||
|
|
||||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
|
||||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
|
||||||
and behavior.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
config ([`LlamaConfig`]):
|
|
||||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
|
||||||
load the weights associated with the model, only the configuration. Check out the
|
|
||||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
|
||||||
LLAMA_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class LlamaPreTrainedModel(PreTrainedModel):
|
|
||||||
config_class = LlamaConfig
|
|
||||||
base_model_prefix = "model"
|
|
||||||
supports_gradient_checkpointing = True
|
|
||||||
_no_split_modules = ["LlamaDecoderLayer"]
|
|
||||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
std = self.config.initializer_range
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
|
||||||
if module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.Embedding):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
|
||||||
if module.padding_idx is not None:
|
|
||||||
module.weight.data[module.padding_idx].zero_()
|
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
|
||||||
if isinstance(module, LlamaModel):
|
|
||||||
module.gradient_checkpointing = value
|
|
||||||
|
|
||||||
|
|
||||||
LLAMA_INPUTS_DOCSTRING = r"""
|
|
||||||
Args:
|
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
|
||||||
it.
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
|
||||||
`past_key_values`).
|
|
||||||
|
|
||||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
|
||||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
|
||||||
information on the default strategy.
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
|
||||||
- 0 indicates the head is **masked**.
|
|
||||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
|
||||||
config.n_positions - 1]`.
|
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
|
||||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
|
||||||
model's internal embedding lookup matrix.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
|
||||||
`past_key_values`).
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
|
||||||
LLAMA_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class LlamaModel(LlamaPreTrainedModel):
|
|
||||||
"""
|
|
||||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: LlamaConfig
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
self.padding_idx = config.pad_token_id
|
|
||||||
self.vocab_size = config.vocab_size
|
|
||||||
|
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
|
||||||
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
|
||||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.embed_tokens
|
|
||||||
|
|
||||||
def set_input_embeddings(self, value):
|
|
||||||
self.embed_tokens = value
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
|
||||||
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
|
||||||
# create causal mask
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = None
|
|
||||||
if input_shape[-1] > 1:
|
|
||||||
combined_attention_mask = _make_causal_mask(
|
|
||||||
input_shape,
|
|
||||||
inputs_embeds.dtype,
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
past_key_values_length=past_key_values_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
|
||||||
inputs_embeds.device
|
|
||||||
)
|
|
||||||
combined_attention_mask = (
|
|
||||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
return combined_attention_mask
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
query_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# retrieve input_ids and inputs_embeds
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
|
||||||
elif input_ids is not None:
|
|
||||||
batch_size, seq_length = input_ids.shape
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
|
||||||
if query_embeds is not None:
|
|
||||||
inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
|
|
||||||
seq_length_with_past = seq_length
|
|
||||||
past_key_values_length = 0
|
|
||||||
|
|
||||||
if past_key_values is not None:
|
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
||||||
|
|
||||||
if position_ids is None:
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
||||||
position_ids = torch.arange(
|
|
||||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
|
||||||
)
|
|
||||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
|
||||||
else:
|
|
||||||
position_ids = position_ids.view(-1, seq_length).long()
|
|
||||||
|
|
||||||
# embed positions
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones(
|
|
||||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
|
||||||
)
|
|
||||||
attention_mask = self._prepare_decoder_attention_mask(
|
|
||||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
# decoder layers
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
all_self_attns = () if output_attentions else None
|
|
||||||
next_decoder_cache = () if use_cache else None
|
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
# None for past_key_value
|
|
||||||
return module(*inputs, output_attentions, None)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(decoder_layer),
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
position_ids,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attns += (layer_outputs[1],)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
|
||||||
return BaseModelOutputWithPast(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=next_cache,
|
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attns,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__(config)
|
|
||||||
self.model = LlamaModel(config)
|
|
||||||
|
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.model.embed_tokens
|
|
||||||
|
|
||||||
def set_input_embeddings(self, value):
|
|
||||||
self.model.embed_tokens = value
|
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
|
||||||
return self.lm_head
|
|
||||||
|
|
||||||
def set_output_embeddings(self, new_embeddings):
|
|
||||||
self.lm_head = new_embeddings
|
|
||||||
|
|
||||||
def set_decoder(self, decoder):
|
|
||||||
self.model = decoder
|
|
||||||
|
|
||||||
def get_decoder(self):
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
Its signature and body then change: the `query_embeds` argument is dropped, a `reduction` argument is added, the docstring's "consciours" typo becomes "conscious", and the logits computation gains upstream's `pretraining_tp` slicing:

````diff
@@ -633,12 +22,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        query_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        reduction: Optional[str] = "mean",
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -657,13 +46,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

-        >>> prompt = "Hey, are you consciours? Can you talk to me?"
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")

         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -679,7 +68,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
-            query_embeds=query_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -687,7 +75,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         )

         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()

         loss = None
         if labels is not None:
````
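The `pretraining_tp` branch reproduces the tensor-parallel arithmetic used during Llama-2 pretraining: `lm_head.weight` is split along the vocabulary dimension, each slice is projected separately, and the pieces are concatenated. A self-contained sanity check, with illustrative shapes rather than the real model's, showing this matches a single matmul:

```python
import torch
import torch.nn.functional as F

tp = 4                           # pretend pretraining_tp value
hidden = torch.randn(2, 10, 64)  # (batch, seq, hidden_size)
weight = torch.randn(128, 64)    # lm_head weight: (vocab_size, hidden_size)

# Split along the vocab dimension, project per slice, re-concatenate.
slices = weight.split(128 // tp, dim=0)
logits_tp = torch.cat([F.linear(hidden, s) for s in slices], dim=-1)

logits_ref = F.linear(hidden, weight)  # single-matmul reference
print(torch.allclose(logits_tp, logits_ref, atol=1e-5))  # True
```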
In the loss computation, `CrossEntropyLoss` picks up the new `reduction` argument, and per-sequence averaging is applied when `reduction` is `"none"`:

```diff
@@ -695,12 +89,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
+            loss_fct = CrossEntropyLoss(reduction=reduction)
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
+            if reduction == "none":
+                loss = loss.view(logits.size(0), -1).mean(1)

         if not return_dict:
             output = (logits,) + outputs[1:]
```
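With `reduction="none"`, the flattened per-token losses are reshaped back to `(batch, seq_len - 1)` and averaged within each sequence, so the caller receives one loss per sample instead of a single batch scalar. A minimal sketch of the new call, assuming the modified `LlamaForCausalLM` above and a tiny throwaway config (values are illustrative only):

```python
import torch
from transformers import LlamaConfig

# Arbitrary small config, for illustration only.
config = LlamaConfig(vocab_size=128, hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4)
model = LlamaForCausalLM(config)

input_ids = torch.randint(0, 128, (3, 10))
out = model(input_ids=input_ids, labels=input_ids, reduction="none")
print(out.loss.shape)  # torch.Size([3]) -- one averaged loss per sequence
```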
Finally, the `query_embeds`-aware `prepare_inputs_for_generation` override and the `_reorder_cache` helper are removed:

```diff
@@ -713,43 +109,3 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
-    def prepare_inputs_for_generation(
-        self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        query_embeds = None
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "query_embeds": query_embeds,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs
-
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
-        return reordered_past
```
Four training-config files are updated in step, swapping the old `arch: mini_gpt4` key for the renamed `minigpt4` architecture (two Llama-2 configs and two Vicuna-v0 configs):

```diff
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2

   max_txt_len: 160
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0

@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0

   max_txt_len: 160
```
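After the rename, the head of a Llama-2 training config reads as below (a representative sketch; keys beyond those shown vary per file):

```yaml
model:
  arch: minigpt4              # renamed from mini_gpt4
  model_type: pretrain_llama2

  max_txt_len: 160
```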