Merge pull request #373 from TsuTikgiau/main

Update to v2
This commit is contained in:
ZhuDeyao 2023-10-13 23:51:49 +03:00 committed by GitHub
commit bf36d7fb89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 1782 additions and 1674 deletions

41
MiniGPT4_Train.md Normal file
View File

@ -0,0 +1,41 @@
## Training of MiniGPT-4
The training of MiniGPT-4 contains two alignment stages.
**1. First pretraining stage**
In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets
to align the vision and language model. To download and prepare the datasets, please check
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
After the first stage, the visual features are mapped and can be understood by the language
model.
To launch the first stage training, run the following command. In our experiments, we use 4 A100.
You can change the save path in the config file
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
```
A MiniGPT-4 checkpoint with only stage one training can be downloaded
[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently.
**2. Second finetuning stage**
In the second stage, we use a small high quality image-text pair dataset created by ourselves
and convert it to a conversation format to further align MiniGPT-4.
To download and prepare our second stage dataset, please check our
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
To launch the second stage alignment,
first specify the path to the checkpoint file trained in stage 1 in
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml).
You can also specify the output path there.
Then, run the following command. In our experiments, we use 1 A100.
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
```
After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly.

Binary file not shown.

BIN
MiniGPTv2.pdf Normal file

Binary file not shown.

View File

@ -1,35 +0,0 @@
## How to Prepare Vicuna Weight
Vicuna is an open-source LLAMA-based LLM that has a performance close to ChatGPT.
We currently use the v0 version of Vicuna-13B.
To prepare Vicunas weight, first download Vicunas **delta** weight from [https://huggingface.co/lmsys/vicuna-13b-delta-v0](https://huggingface.co/lmsys/vicuna-13b-delta-v0).
In case you have git-lfs installed (https://git-lfs.com), this can be done by
```
git lfs install
git clone https://huggingface.co/lmsys/vicuna-13b-delta-v0 # more powerful, need at least 24G gpu memory
# or
git clone https://huggingface.co/lmsys/vicuna-7b-delta-v0 # smaller, need 12G gpu memory
```
Note that this is not directly the working weight, but the difference between the working weight and the original weight of LLAMA-13B. (Due to LLAMAs rules, we cannot distribute the weight of LLAMA.)
Then, you need to obtain the original LLAMA-7B or LLAMA-13B weights in the HuggingFace format
either following the instruction provided by HuggingFace
[here](https://huggingface.co/docs/transformers/main/model_doc/llama) or from the Internet.
When these two weights are ready, we can use tools from Vicunas team to create the real working weight.
First, Install their library that is compatible with v0 Vicuna by
```
pip install git+https://github.com/lm-sys/FastChat.git@v0.1.10
```
Then, run the following command to create the final working weight
```
python -m fastchat.model.apply_delta --base /path/to/llama-13bOR7b-hf/ --target /path/to/save/working/vicuna/weight/ --delta /path/to/vicuna-13bOR7b-delta-v0/
```
Now you are good to go!

153
README.md
View File

@ -1,24 +1,48 @@
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
[Deyao Zhu](https://tsutikgiau.github.io/)* , [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
# MiniGPT-V
**King Abdullah University of Science and Technology**
<font size='5'>**MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning**</font>
Jun Chen, Deyao Zhu, Xiaoqian Shen, Xiang Li, Zechun Liu, Pengchuan Zhang, Raghuraman Krishnamoorthi, Vikas Chandra, Yunyang Xiong☨, Mohamed Elhoseiny☨
☨equal last author
<a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://github.com/Vision-CAIR/MiniGPT-4/blob/main/MiniGPTv2.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a> <a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Gradio-Demo-blue'></a> [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=atFCwV2hSY4)
<font size='5'>**MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models**</font>
Deyao Zhu*, Jun Chen*, Xiaoqian Shen, Xiang Li, Mohamed Elhoseiny
*equal contribution
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
*King Abdullah University of Science and Technology*
## 💡 Get help - [Q&A](https://github.com/Vision-CAIR/MiniGPT-4/discussions/categories/q-a) or [Discord 💬](https://discord.gg/5WdJkjbAeE)
## News
We now provide a llama 2 version of MiniGPT-4
[Oct.13 2023] Breaking! We release the first major update with our MiniGPT-v2
[Aug.28 2023] We now provide a llama 2 version of MiniGPT-4
## Online Demo
Click the image to chat with MiniGPT-v2 around your images
[![demo](figs/minigpt2_demo.png)](https://minigpt-v2.github.io/)
Click the image to chat with MiniGPT-4 around your images
[![demo](figs/online_demo.png)](https://minigpt-4.github.io)
## Examples
## MiniGPT-v2 Examples
![MiniGPT-v2 demos](figs/demo.png)
## MiniGPT-4 Examples
| | |
:-------------------------:|:-------------------------:
![find wild](figs/examples/wop_2.png) | ![write story](figs/examples/ad_2.png)
@ -28,17 +52,6 @@ More examples can be found in the [project page](https://minigpt-4.github.io).
## Introduction
- MiniGPT-4 aligns a frozen visual encoder from BLIP-2 with a frozen LLM, Vicuna, using just one projection layer.
- We train MiniGPT-4 with two stages. The first traditional pretraining stage is trained using roughly 5 million aligned image-text pairs in 10 hours using 4 A100s. After the first stage, Vicuna is able to understand the image. But the generation ability of Vicuna is heavily impacted.
- To address this issue and improve usability, we propose a novel way to create high-quality image-text pairs by the model itself and ChatGPT together. Based on this, we then create a small (3500 pairs in total) yet high-quality dataset.
- The second finetuning stage is trained on this dataset in a conversation template to significantly improve its generation reliability and overall usability. To our surprise, this stage is computationally efficient and takes only around 7 minutes with a single A100.
- MiniGPT-4 yields many emerging vision-language capabilities similar to those demonstrated in GPT-4.
![overview](figs/overview.png)
## Getting Started
### Installation
@ -56,42 +69,62 @@ conda activate minigpt4
**2. Prepare the pretrained LLM weights**
Currently, we provide both Vicuna V0 and Llama 2 version of MiniGPT-4.
**MiniGPT-v2** is based on Llama2 Chat 7B. For **MiniGPT-4**, we have both Vicuna V0 and Llama 2 version.
Download the corresponding LLM weights from the following huggingface space via clone the repository using git-lfs.
| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B |
| Llama 2 Chat 7B | Vicuna V0 13B | Vicuna V0 7B |
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main)
[Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) | [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main)
Then, set the path to the vicuna weight in the model config file
[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18
and/or the path to the llama2 weight in the model config file
Then, set the variable *llama_model* in the model config file to the LLM weight path.
* For MiniGPT-v2, set the LLM path
[here](minigpt4/configs/models/minigpt_v2.yaml#L15) at Line 14.
* For MiniGPT-4 (Llama2), set the LLM path
[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
**3. Prepare the pretrained MiniGPT-4 checkpoint**
* For MiniGPT-4 (Vicuna), set the LLM path
[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18
Download the pretrained checkpoints according to the Vicuna model you prepare.
**3. Prepare the pretrained model checkpoints**
| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B |
:------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing)
Download the pretrained model checkpoints
Then, set the path to the pretrained checkpoint in the evaluation config file
| MiniGPT-v2 (LLaMA-2 Chat 7B) |
|------------------------------|
| [Download](https://drive.google.com/file/d/1aVbfW7nkCSYx99_vCRyP1sOlQiWVSnAl/view?usp=sharing) |
For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8.
| MiniGPT-4 (Vicuna 13B) | MiniGPT-4 (Vicuna 7B) | MiniGPT-4 (LLaMA-2 Chat 7B) |
|----------------------------|---------------------------|---------------------------------|
| [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) |
For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version.
### Launching Demo Locally
Try out our demo [demo.py](demo.py) for the vicuna version on your local machine by running
For MiniGPT-v2, run
```
python demo_v2.py --cfg-path eval_configs/minigpt4v2_eval.yaml --gpu-id 0
```
For MiniGPT-4 (Vicuna version), run
```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```
or for Llama 2 version by
For MiniGPT-4 (Llama2 version), run
```
python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
@ -101,52 +134,17 @@ python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
To save GPU memory, LLMs loads as 8 bit by default, with a beam search width of 1.
This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memory for 7B LLM.
For more powerful GPUs, you can run the model
in 16 bit by setting `low_resource` to `False` in the relevant config file
(line 6 of either [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6) if using Vicuna or [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6) if using Llama 2) and use a larger beam search width.
in 16 bit by setting `low_resource` to `False` in the relevant config file:
Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run our code on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
* MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#6)
* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6)
* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6)
Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing)
### Training
The training of MiniGPT-4 contains two alignment stages.
**1. First pretraining stage**
In the first pretrained stage, the model is trained using image-text pairs from Laion and CC datasets
to align the vision and language model. To download and prepare the datasets, please check
our [first stage dataset preparation instruction](dataset/README_1_STAGE.md).
After the first stage, the visual features are mapped and can be understood by the language
model.
To launch the first stage training, run the following command. In our experiments, we use 4 A100.
You can change the save path in the config file
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml)
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml
```
A MiniGPT-4 checkpoint with only stage one training can be downloaded
[here (13B)](https://drive.google.com/file/d/1u9FRRBB3VovP1HxCAlpD9Lw4t4P6-Yq8/view?usp=share_link) or [here (7B)](https://drive.google.com/file/d/1HihQtCEXUyBM1i9DQbaK934wW3TZi-h5/view?usp=share_link).
Compared to the model after stage two, this checkpoint generate incomplete and repeated sentences frequently.
**2. Second finetuning stage**
In the second stage, we use a small high quality image-text pair dataset created by ourselves
and convert it to a conversation format to further align MiniGPT-4.
To download and prepare our second stage dataset, please check our
[second stage dataset preparation instruction](dataset/README_2_STAGE.md).
To launch the second stage alignment,
first specify the path to the checkpoint file trained in stage 1 in
[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml).
You can also specify the output path there.
Then, run the following command. In our experiments, we use 1 A100.
```bash
torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml
```
After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly.
For training details of MiniGPT-4, check [here](MiniGPT4_Train.md).
@ -156,10 +154,19 @@ After the second stage alignment, MiniGPT-4 is able to talk about the image cohe
+ [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) The model architecture of MiniGPT-4 follows BLIP-2. Don't forget to check this great open-source work if you don't know it before!
+ [Lavis](https://github.com/salesforce/LAVIS) This repository is built upon Lavis!
+ [Vicuna](https://github.com/lm-sys/FastChat) The fantastic language ability of Vicuna with only 13B parameters is just amazing. And it is open-source!
+ [LLaMA](https://github.com/facebookresearch/llama) The strong open-sourced LLaMA 2 language model.
If you're using MiniGPT-4 in your research or applications, please cite using this BibTeX:
If you're using MiniGPT-4/MiniGPT-v2 in your research or applications, please cite using this BibTeX:
```bibtex
@article{Chen2023minigpt,
title={MiniGPT-v2: Large Language Model as a Unified Interface for Vision-Language Multi-task Learning},
author={Chen, Jun and Zhu, Deyao and Shen, Xiaoqian and Li, Xiang and Liu, Zechu and Zhang, Pengchuan and Krishnamoorthi, Raghuraman and Chandra, Vikas and Xiong, Yunyang and Elhoseiny, Mohamed},
journal={github},
year={2023}
}
@article{zhu2023minigpt,
title={MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models},
author={Zhu, Deyao and Chen, Jun and Shen, Xiaoqian and Li, Xiang and Elhoseiny, Mohamed},

16
demo.py
View File

@ -7,10 +7,12 @@ import torch
import torch.backends.cudnn as cudnn
import gradio as gr
from transformers import StoppingCriteriaList
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2, StoppingCriteriaSub
# imports modules for registration
from minigpt4.datasets.builders import *
@ -66,7 +68,12 @@ CONV_VISION = conv_dict[model_config.model_type]
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
stop_words_ids = [[835], [2277, 29937]]
stop_words_ids = [torch.tensor(ids).to(device='cuda:{}'.format(args.gpu_id)) for ids in stop_words_ids]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), stopping_criteria=stopping_criteria)
print('Initialization Finished')
@ -89,6 +96,7 @@ def upload_img(gr_img, text_input, chat_state):
chat_state = CONV_VISION.copy()
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
chat.encode_img(img_list)
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
@ -124,7 +132,7 @@ with gr.Blocks() as demo:
gr.Markdown(article)
with gr.Row():
with gr.Column(scale=0.5):
with gr.Column(scale=1):
image = gr.Image(type="pil")
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
clear = gr.Button("Restart")
@ -147,7 +155,7 @@ with gr.Blocks() as demo:
label="Temperature",
)
with gr.Column():
with gr.Column(scale=2):
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(label='MiniGPT-4')

662
demo_v2.py Normal file
View File

@ -0,0 +1,662 @@
import argparse
import os
import random
from collections import defaultdict
import cv2
import re
import numpy as np
from PIL import Image
import torch
import html
import gradio as gr
import torchvision.transforms as T
import torch.backends.cudnn as cudnn
from minigpt4.common.config import Config
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
def parse_args():
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml',
help="path to configuration file.")
parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
parser.add_argument(
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)
args = parser.parse_args()
return args
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
cudnn.benchmark = False
cudnn.deterministic = True
print('Initializing Chat')
args = parse_args()
cfg = Config(args)
device = 'cuda:{}'.format(args.gpu_id)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
bounding_box_size = 100
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
model = model.eval()
CONV_VISION = Conversation(
system="",
roles=(r"<s>[INST] ", r" [/INST]"),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="",
)
def extract_substrings(string):
# first check if there is no-finished bracket
index = string.rfind('}')
if index != -1:
string = string[:index + 1]
pattern = r'<p>(.*?)\}(?!<)'
matches = re.findall(pattern, string)
substrings = [match for match in matches]
return substrings
def is_overlapping(rect1, rect2):
x1, y1, x2, y2 = rect1
x3, y3, x4, y4 = rect2
return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
def computeIoU(bbox1, bbox2):
x1, y1, x2, y2 = bbox1
x3, y3, x4, y4 = bbox2
intersection_x1 = max(x1, x3)
intersection_y1 = max(y1, y3)
intersection_x2 = min(x2, x4)
intersection_y2 = min(y2, y4)
intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1)
bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
union_area = bbox1_area + bbox2_area - intersection_area
iou = intersection_area / union_area
return iou
def save_tmp_img(visual_img):
file_name = "".join([str(random.randint(0, 9)) for _ in range(5)]) + ".jpg"
file_path = "/tmp/" + file_name
visual_img.save(file_path)
return file_path
def mask2bbox(mask):
if mask is None:
return ''
mask = mask.resize([100, 100], resample=Image.NEAREST)
mask = np.array(mask)[:, :, 0]
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
if rows.sum():
# Get the top, bottom, left, and right boundaries
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]
bbox = '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
else:
bbox = ''
return bbox
def escape_markdown(text):
# List of Markdown special characters that need to be escaped
md_chars = ['<', '>']
# Escape each special character
for char in md_chars:
text = text.replace(char, '\\' + char)
return text
def reverse_escape(text):
md_chars = ['\\<', '\\>']
for char in md_chars:
text = text.replace(char, char[1:])
return text
colors = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(210, 210, 0),
(255, 0, 255),
(0, 255, 255),
(114, 128, 250),
(0, 165, 255),
(0, 128, 0),
(144, 238, 144),
(238, 238, 175),
(255, 191, 0),
(0, 128, 0),
(226, 43, 138),
(255, 0, 255),
(0, 215, 255),
]
color_map = {
f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
color_id, color in enumerate(colors)
}
used_colors = colors
def visualize_all_bbox_together(image, generation):
if image is None:
return None, ''
generation = html.unescape(generation)
print('gen begin', generation)
image_width, image_height = image.size
image = image.resize([500, int(500 / image_width * image_height)])
image_width, image_height = image.size
string_list = extract_substrings(generation)
if string_list: # it is grounding or detection
mode = 'all'
entities = defaultdict(list)
i = 0
j = 0
for string in string_list:
try:
obj, string = string.split('</p>')
except ValueError:
print('wrong string: ', string)
continue
bbox_list = string.split('<delim>')
flag = False
for bbox_string in bbox_list:
integers = re.findall(r'-?\d+', bbox_string)
if len(integers) == 4:
x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
left = x0 / bounding_box_size * image_width
bottom = y0 / bounding_box_size * image_height
right = x1 / bounding_box_size * image_width
top = y1 / bounding_box_size * image_height
entities[obj].append([left, bottom, right, top])
j += 1
flag = True
if flag:
i += 1
else:
integers = re.findall(r'-?\d+', generation)
if len(integers) == 4: # it is refer
mode = 'single'
entities = list()
x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
left = x0 / bounding_box_size * image_width
bottom = y0 / bounding_box_size * image_height
right = x1 / bounding_box_size * image_width
top = y1 / bounding_box_size * image_height
entities.append([left, bottom, right, top])
else:
# don't detect any valid bbox to visualize
return None, ''
if len(entities) == 0:
return None, ''
if isinstance(image, Image.Image):
image_h = image.height
image_w = image.width
image = np.array(image)
elif isinstance(image, str):
if os.path.exists(image):
pil_img = Image.open(image).convert("RGB")
image = np.array(pil_img)[:, :, [2, 1, 0]]
image_h = pil_img.height
image_w = pil_img.width
else:
raise ValueError(f"invaild image path, {image}")
elif isinstance(image, torch.Tensor):
image_tensor = image.cpu()
reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
pil_img = T.ToPILImage()(image_tensor)
image_h = pil_img.height
image_w = pil_img.width
image = np.array(pil_img)[:, :, [2, 1, 0]]
else:
raise ValueError(f"invaild image format, {type(image)} for {image}")
indices = list(range(len(entities)))
new_image = image.copy()
previous_bboxes = []
# size of text
text_size = 0.5
# thickness of text
text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
box_line = 2
(c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
base_height = int(text_height * 0.675)
text_offset_original = text_height - base_height
text_spaces = 2
# num_bboxes = sum(len(x[-1]) for x in entities)
used_colors = colors # random.sample(colors, k=num_bboxes)
color_id = -1
for entity_idx, entity_name in enumerate(entities):
if mode == 'single' or mode == 'identify':
bboxes = entity_name
bboxes = [bboxes]
else:
bboxes = entities[entity_name]
color_id += 1
for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
skip_flag = False
orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
if mode == 'all':
l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
x1 = orig_x1 - l_o
y1 = orig_y1 - l_o
if y1 < text_height + text_offset_original + 2 * text_spaces:
y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
x1 = orig_x1 + r_o
# add text background
(text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
text_line)
text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
for prev_bbox in previous_bboxes:
if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
prev_bbox['phrase'] == entity_name:
skip_flag = True
break
while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
y1 += (text_height + text_offset_original + 2 * text_spaces)
if text_bg_y2 >= image_h:
text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
text_bg_y2 = image_h
y1 = image_h
break
if not skip_flag:
alpha = 0.5
for i in range(text_bg_y1, text_bg_y2):
for j in range(text_bg_x1, text_bg_x2):
if i < image_h and j < image_w:
if j < text_bg_x1 + 1.35 * c_width:
# original color
bg_color = color
else:
# white
bg_color = [255, 255, 255]
new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
np.uint8)
cv2.putText(
new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
)
previous_bboxes.append(
{'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
if mode == 'all':
def color_iterator(colors):
while True:
for color in colors:
yield color
color_gen = color_iterator(colors)
# Add colors to phrases and remove <p></p>
def colored_phrases(match):
phrase = match.group(1)
color = next(color_gen)
return f'<span style="color:rgb{color}">{phrase}</span>'
print('gen before', generation)
generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
print('gen after', generation)
generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
else:
generation_colored = ''
pil_image = Image.fromarray(new_image)
return pil_image, generation_colored
def gradio_reset(chat_state, img_list):
if chat_state is not None:
chat_state.messages = []
if img_list is not None:
img_list = []
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
interactive=True), chat_state, img_list
def image_upload_trigger(upload_flag, replace_flag, img_list):
# set the upload flag to true when receive a new image.
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
print('flag', upload_flag, replace_flag)
print("SET UPLOAD FLAG!")
upload_flag = 1
if img_list:
print("SET REPLACE FLAG!")
replace_flag = 1
print('flag', upload_flag, replace_flag)
return upload_flag, replace_flag
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
# set the upload flag to true when receive a new image.
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
print('flag', upload_flag, replace_flag)
print("SET UPLOAD FLAG!")
upload_flag = 1
if img_list or replace_flag == 1:
print("SET REPLACE FLAG!")
replace_flag = 1
print('flag', upload_flag, replace_flag)
return upload_flag, replace_flag
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
if isinstance(gr_img, dict):
gr_img, mask = gr_img['image'], gr_img['mask']
else:
mask = None
if '[identify]' in user_message:
# check if user provide bbox in the text input
integers = re.findall(r'-?\d+', user_message)
if len(integers) != 4: # no bbox in text
bbox = mask2bbox(mask)
user_message = user_message + bbox
if len(user_message) == 0:
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
if chat_state is None:
chat_state = CONV_VISION.copy()
print('upload flag: {}'.format(upload_flag))
if upload_flag:
if replace_flag:
print('RESET!!!!!!!')
chat_state = CONV_VISION.copy() # new image, reset everything
replace_flag = 0
chatbot = []
print('UPLOAD IMAGE!!')
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
upload_flag = 0
chat.ask(user_message, chat_state)
chatbot = chatbot + [[user_message, None]]
if '[identify]' in user_message:
visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
if visual_img is not None:
print('Visualizing the input')
file_path = save_tmp_img(visual_img)
chatbot = chatbot + [[(file_path,), None]]
return '', chatbot, chat_state, img_list, upload_flag, replace_flag
def gradio_answer(chatbot, chat_state, img_list, temperature):
llm_message = chat.answer(conv=chat_state,
img_list=img_list,
temperature=temperature,
max_new_tokens=500,
max_length=2000)[0]
chatbot[-1][1] = llm_message
return chatbot, chat_state
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
print('chat state', chat_state.get_prompt())
if not isinstance(img_list[0], torch.Tensor):
chat.encode_img(img_list)
streamer = chat.stream_answer(conv=chat_state,
img_list=img_list,
temperature=temperature,
max_new_tokens=500,
max_length=2000)
output = ''
for new_output in streamer:
escapped = escape_markdown(new_output)
output += escapped
chatbot[-1][1] = output
yield chatbot, chat_state
# print('message: ', chat_state.messages)
chat_state.messages[-1][1] = '</s>'
return chatbot, chat_state
def gradio_visualize(chatbot, gr_img):
if isinstance(gr_img, dict):
gr_img, mask = gr_img['image'], gr_img['mask']
unescaped = reverse_escape(chatbot[-1][1])
visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
if visual_img is not None:
print('Visualizing the output')
if len(generation_color):
chatbot[-1][1] = generation_color
file_path = save_tmp_img(visual_img)
chatbot = chatbot + [[None, (file_path,)]]
return chatbot
def gradio_taskselect(idx):
prompt_list = [
'',
'[grounding] describe this image in detail',
'[refer] ',
'[detection] ',
'[identify] what is this ',
'[vqa] '
]
instruct_list = [
'**Hint:** Type in whatever you want',
'**Hint:** Send the command to generate a grounded image description',
'**Hint:** Type in a phrase about an object in the image and send the command',
'**Hint:** Type in a caption or phrase, and see object locations in the image',
'**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" botton on the top right of the image before redraw',
'**Hint:** Send a question to get a short answer',
]
return prompt_list[idx], instruct_list[idx]
chat = Chat(model, vis_processor, device=device)
title = """<h1 align="center">MiniGPT-v2 Demo</h1>"""
description = 'Welcome to Our MiniGPT-v2 Chatbot Demo!'
# article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4/blob/main/MiniGPTv2.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/GitHub-Repo-blue'></a></p><p><a href='https://www.youtube.com/watch?v=atFCwV2hSY4'><img src='https://img.shields.io/badge/YouTube-Video-red'></a></p>"""
article = """<p><a href='https://minigpt-v2.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
introduction = '''
For Abilities Involving Visual Grounding:
1. Grounding: CLICK **Send** to generate a grounded image description.
2. Refer: Input a referring object and CLICK **Send**.
3. Detection: Write a caption or phrase, and CLICK **Send**.
4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
5. VQA: Input a visual question and CLICK **Send**.
6. No Tag: Input whatever you want and CLICK **Send** without any tagging
You can also simply chat in free form!
'''
text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
scale=8)
with gr.Blocks() as demo:
gr.Markdown(title)
# gr.Markdown(description)
gr.Markdown(article)
with gr.Row():
with gr.Column(scale=0.5):
image = gr.Image(type="pil", tool='sketch', brush_radius=20)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.0,
step=0.1,
interactive=True,
label="Temperature",
)
clear = gr.Button("Restart")
gr.Markdown(introduction)
with gr.Column():
chat_state = gr.State(value=None)
img_list = gr.State(value=[])
chatbot = gr.Chatbot(label='MiniGPT-v2')
dataset = gr.Dataset(
components=[gr.Textbox(visible=False)],
samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
type="index",
label='Task Shortcuts',
)
task_inst = gr.Markdown('**Hint:** Upload your image and chat')
with gr.Row():
text_input.render()
send = gr.Button("Send", variant='primary', size='sm', scale=1)
upload_flag = gr.State(value=0)
replace_flag = gr.State(value=0)
image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
with gr.Row():
with gr.Column():
gr.Examples(examples=[
["examples_v2/office.jpg", "[grounding] describe this image in detail", upload_flag, replace_flag,
img_list],
["examples_v2/sofa.jpg", "[detection] sofas", upload_flag, replace_flag, img_list],
["examples_v2/2000x1372_wmkn_0012149409555.jpg", "[refer] the world cup", upload_flag, replace_flag,
img_list],
["examples_v2/KFC-20-for-20-Nuggets.jpg", "[identify] what is this {<4><50><30><65>}", upload_flag,
replace_flag, img_list],
], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
outputs=[upload_flag, replace_flag])
with gr.Column():
gr.Examples(examples=[
["examples_v2/glip_test.jpg", "[vqa] where should I hide in this room when playing hide and seek",
upload_flag, replace_flag, img_list],
["examples_v2/float.png", "Please write a poem about the image", upload_flag, replace_flag, img_list],
["examples_v2/thief.png", "Is the weapon fateful", upload_flag, replace_flag, img_list],
["examples_v2/cockdial.png", "What might happen in this image in the next second", upload_flag,
replace_flag, img_list],
], inputs=[image, text_input, upload_flag, replace_flag, img_list], fn=example_trigger,
outputs=[upload_flag, replace_flag])
dataset.click(
gradio_taskselect,
inputs=[dataset],
outputs=[text_input, task_inst],
show_progress="hidden",
postprocess=False,
queue=False,
)
text_input.submit(
gradio_ask,
[text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
[text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
).success(
gradio_stream_answer,
[chatbot, chat_state, img_list, temperature],
[chatbot, chat_state]
).success(
gradio_visualize,
[chatbot, image],
[chatbot],
queue=False,
)
send.click(
gradio_ask,
[text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
[text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
).success(
gradio_stream_answer,
[chatbot, chat_state, img_list, temperature],
[chatbot, chat_state]
).success(
gradio_visualize,
[chatbot, image],
[chatbot],
queue=False,
)
clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
demo.launch(share=True, enable_queue=True)

View File

@ -7,57 +7,27 @@ dependencies:
- python=3.9
- cudatoolkit
- pip
- pytorch=1.12.1
- pytorch-mutex=1.0=cuda
- torchaudio=0.12.1
- torchvision=0.13.1
- pip:
- accelerate==0.16.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- async-timeout==4.0.2
- attrs==22.2.0
- bitsandbytes==0.37.0
- cchardet==2.1.7
- chardet==5.1.0
- contourpy==1.0.7
- cycler==0.11.0
- filelock==3.9.0
- fonttools==4.38.0
- frozenlist==1.3.3
- huggingface-hub==0.13.4
- importlib-resources==5.12.0
- kiwisolver==1.4.4
- torch==2.0.0
- torchaudio
- torchvision
- huggingface-hub==0.18.0
- matplotlib==3.7.0
- multidict==6.0.4
- openai==0.27.0
- packaging==23.0
- psutil==5.9.4
- pycocotools==2.0.6
- pyparsing==3.0.9
- python-dateutil==2.8.2
- iopath
- pyyaml==6.0
- regex==2022.10.31
- tokenizers==0.13.2
- tqdm==4.64.1
- transformers==4.28.0
- transformers==4.30.0
- timm==0.6.13
- spacy==3.5.1
- webdataset==0.2.48
- scikit-learn==1.2.2
- scipy==1.10.1
- yarl==1.8.2
- zipp==3.14.0
- omegaconf==2.3.0
- opencv-python==4.7.0.72
- iopath==0.1.10
- decord==0.6.0
- tenacity==8.2.2
- peft
- pycocoevalcap
- peft==0.2.0
- sentence-transformers
- umap-learn
- notebook
- gradio==3.24.1
- gradio-client==0.0.8
- gradio==3.47.1
- accelerate==0.20.3
- bitsandbytes==0.37.0
- wandb

View File

@ -1,11 +1,11 @@
model:
arch: mini_gpt4
arch: minigpt4
model_type: pretrain_vicuna0
max_txt_len: 160
end_sym: "###"
low_resource: True
prompt_template: '###Human: {} ###Assistant: '
ckpt: '/path/to/checkpoint/'
ckpt: 'please set this value to the path of pretrained checkpoint'
datasets:

View File

@ -1,11 +1,11 @@
model:
arch: mini_gpt4
arch: minigpt4
model_type: pretrain_llama2
max_txt_len: 160
end_sym: "</s>"
low_resource: True
prompt_template: '[INST] {} [/INST] '
ckpt: '/path/to/checkpoint/'
ckpt: 'please set this value to the path of pretrained checkpoint'
datasets:

View File

@ -0,0 +1,24 @@
model:
arch: minigpt_v2
model_type: pretrain
max_txt_len: 160
end_sym: "</s>"
low_resource: True
prompt_template: '[INST] {} [/INST]'
ckpt: 'please set this value to the path of pretrained checkpoint'
lora_r: 64
lora_alpha: 16
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 448
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 83 KiB

BIN
examples_v2/cockdial.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

BIN
examples_v2/float.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

BIN
examples_v2/glip_test.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

BIN
examples_v2/office.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

BIN
examples_v2/sofa.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 116 KiB

BIN
examples_v2/thief.png Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 865 KiB

BIN
figs/demo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

View File

@ -1,5 +1,5 @@
model:
arch: mini_gpt4
arch: minigpt4
# vit encoder
image_size: 224
@ -12,7 +12,7 @@ model:
# generation configs
prompt: ""
llama_model: "/path/to/llama2/weight"
llama_model: "please set this value to the path of llama2-chat-7b"
preprocess:
vis_processor:

View File

@ -1,5 +1,5 @@
model:
arch: mini_gpt4
arch: minigpt4
# vit encoder
image_size: 224
@ -15,7 +15,7 @@ model:
# generation configs
prompt: ""
llama_model: "/path/to/vicuna/weight"
llama_model: "please set this value to the path of vicuna model"
preprocess:
vis_processor:

View File

@ -0,0 +1,31 @@
model:
arch: minigpt_v2
# vit encoder
image_size: 448
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
# generation configs
prompt: ""
llama_model: "please set this value to the path of llama2-chat-7b"
lora_r: 64
lora_alpha: 16
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 448
eval:
name: "blip2_image_eval"
image_size: 448
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"

View File

@ -1,10 +1,11 @@
import argparse
import time
from threading import Thread
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import dataclasses
from enum import auto, Enum
@ -129,13 +130,16 @@ CONV_VISION_LLama2 = Conversation(
class Chat:
def __init__(self, model, vis_processor, device='cuda:0'):
def __init__(self, model, vis_processor, device='cuda:0', stopping_criteria=None):
self.device = device
self.model = model
self.vis_processor = vis_processor
stop_words_ids = [torch.tensor([835]).to(self.device),
torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
if stopping_criteria is not None:
self.stopping_criteria = stopping_criteria
else:
stop_words_ids = [torch.tensor([2]).to(self.device)]
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
def ask(self, text, conv):
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
@ -144,8 +148,8 @@ class Chat:
else:
conv.append_message(conv.roles[0], text)
def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
def answer_prepare(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.05, length_penalty=1, temperature=1.0, max_length=2000):
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
@ -154,10 +158,9 @@ class Chat:
print('Warning: The number of tokens in current conversation exceeds the max length. '
'The model will not see the contexts outside the range.')
begin_idx = max(0, current_max_len - max_length)
embs = embs[:, begin_idx:]
outputs = self.model.llama_model.generate(
generation_kwargs = dict(
inputs_embeds=embs,
max_new_tokens=max_new_tokens,
stopping_criteria=self.stopping_criteria,
@ -169,18 +172,31 @@ class Chat:
length_penalty=length_penalty,
temperature=temperature,
)
output_token = outputs[0]
if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
output_token = output_token[1:]
if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
output_token = output_token[1:]
output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
return generation_kwargs
def answer(self, conv, img_list, **kargs):
generation_dict = self.answer_prepare(conv, img_list, **kargs)
output_token = self.model.llama_model.generate(**generation_dict)[0]
output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_text = output_text.split('###')[0] # remove the stop sign '###'
output_text = output_text.split('Assistant:')[-1].strip()
conv.messages[-1][1] = output_text
return output_text, output_token.cpu().numpy()
def upload_img(self, image, conv, img_list):
def stream_answer(self, conv, img_list, **kargs):
generation_kwargs = self.answer_prepare(conv, img_list, **kargs)
streamer = TextIteratorStreamer(self.model.llama_tokenizer, skip_special_tokens=True)
generation_kwargs['streamer'] = streamer
thread = Thread(target=self.model.llama_model.generate, kwargs=generation_kwargs)
thread.start()
return streamer
def encode_img(self, img_list):
image = img_list[0]
img_list.pop(0)
if isinstance(image, str): # is a image path
raw_image = Image.open(image).convert('RGB')
image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
@ -194,9 +210,12 @@ class Chat:
image_emb, _ = self.model.encode_img(image)
img_list.append(image_emb)
def upload_img(self, image, conv, img_list):
conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
img_list.append(image)
msg = "Received."
# self.conv.append_message(self.conv.roles[1], msg)
return msg
def get_context_emb(self, conv, img_list):
@ -209,7 +228,9 @@ class Chat:
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
print('debug device: ', self.device)
print('debug model device: ', self.model.device)
seg_embs = [self.model.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs

View File

@ -11,16 +11,18 @@ from omegaconf import OmegaConf
from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
from minigpt4.models.blip2 import Blip2Base
from minigpt4.models.mini_gpt4 import MiniGPT4
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.minigpt4 import MiniGPT4
from minigpt4.models.minigpt_v2 import MiniGPTv2
from minigpt4.processors.base_processor import BaseProcessor
__all__ = [
"load_model",
"BaseModel",
"Blip2Base",
"MiniGPTBase",
"MiniGPT4",
"MiniGPTv2"
]

View File

@ -5,15 +5,26 @@
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import os
import logging
import contextlib
from omegaconf import OmegaConf
import numpy as np
import torch
import torch.nn as nn
from transformers import BertTokenizer, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
)
from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from minigpt4.common.utils import get_abs_path, is_url
from omegaconf import OmegaConf
from minigpt4.models.eva_vit import create_eva_vit_g
class BaseModel(nn.Module):
@ -117,131 +128,121 @@ class BaseModel(nn.Module):
else:
return tot
def maybe_autocast(self, dtype=torch.float16):
# if on cpu, don't use autocast
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
enable_autocast = self.device != torch.device("cpu")
class BaseEncoder(nn.Module):
"""
Base class for primitive encoders, such as ViT, TimeSformer, etc.
"""
if enable_autocast:
return torch.cuda.amp.autocast(dtype=dtype)
else:
return contextlib.nullcontext()
def __init__(self):
super().__init__()
@classmethod
def init_vision_encoder(
cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
):
logging.info('Loading VIT')
def forward_features(self, samples, **kwargs):
raise NotImplementedError
assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
if not freeze:
precision = "fp32" # fp16 is not for training
@property
def device(self):
return list(self.parameters())[0].device
visual_encoder = create_eva_vit_g(
img_size, drop_path_rate, use_grad_checkpoint, precision
)
ln_vision = LayerNorm(visual_encoder.num_features)
if freeze:
for name, param in visual_encoder.named_parameters():
param.requires_grad = False
visual_encoder = visual_encoder.eval()
visual_encoder.train = disabled_train
for name, param in ln_vision.named_parameters():
param.requires_grad = False
ln_vision = ln_vision.eval()
ln_vision.train = disabled_train
logging.info("freeze vision encoder")
logging.info('Loading VIT Done')
return visual_encoder, ln_vision
def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
lora_target_modules=["q_proj","v_proj"], **lora_kargs):
logging.info('Loading LLAMA')
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
llama_tokenizer.pad_token = "$$"
if low_resource:
llama_model = LlamaForCausalLM.from_pretrained(
llama_model_path,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map={'': low_res_device}
)
else:
llama_model = LlamaForCausalLM.from_pretrained(
llama_model_path,
torch_dtype=torch.float16,
)
if lora_r > 0:
llama_model = prepare_model_for_int8_training(llama_model)
loraconfig = LoraConfig(
r=lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=lora_target_modules,
**lora_kargs
)
llama_model = get_peft_model(llama_model, loraconfig)
llama_model.print_trainable_parameters()
else:
for name, param in llama_model.named_parameters():
param.requires_grad = False
logging.info('Loading LLAMA Done')
return llama_model, llama_tokenizer
class SharedQueueMixin:
@torch.no_grad()
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
# gather keys before updating queue
image_feats = concat_all_gather(image_feat)
text_feats = concat_all_gather(text_feat)
def load_from_pretrained(self, url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
raise RuntimeError("checkpoint url or path is invalid")
batch_size = image_feats.shape[0]
state_dict = checkpoint["model"]
ptr = int(self.queue_ptr)
assert self.queue_size % batch_size == 0 # for simplicity
msg = self.load_state_dict(state_dict, strict=False)
# replace the keys at ptr (dequeue and enqueue)
self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
# logging.info("Missing keys {}".format(msg.missing_keys))
logging.info("load checkpoint from %s" % url_or_filename)
if idxs is not None:
idxs = concat_all_gather(idxs)
self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
ptr = (ptr + batch_size) % self.queue_size # move pointer
self.queue_ptr[0] = ptr
return msg
class MomentumDistilationMixin:
@torch.no_grad()
def copy_params(self):
for model_pair in self.model_pairs:
for param, param_m in zip(
model_pair[0].parameters(), model_pair[1].parameters()
):
param_m.data.copy_(param.data) # initialize
param_m.requires_grad = False # not update by gradient
@torch.no_grad()
def _momentum_update(self):
for model_pair in self.model_pairs:
for param, param_m in zip(
model_pair[0].parameters(), model_pair[1].parameters()
):
param_m.data = param_m.data * self.momentum + param.data * (
1.0 - self.momentum
)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class GatherLayer(torch.autograd.Function):
"""
Gather tensors from all workers with support for backward propagation:
This implementation does not cut the gradients as torch.distributed.all_gather does.
"""
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
@staticmethod
def forward(ctx, x):
output = [
torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(output, x)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
all_gradients = torch.stack(grads)
torch.distributed.all_reduce(all_gradients)
return all_gradients[torch.distributed.get_rank()]
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
def all_gather_with_grad(tensors):
"""
Performs all_gather operation on the provided tensors.
Graph remains connected for backward grad computation.
"""
# Queue the gathered tensors
world_size = torch.distributed.get_world_size()
# There is no need for reduction in the single-proc case
if world_size == 1:
return tensors
# tensor_all = GatherLayer.apply(tensors)
tensor_all = GatherLayer.apply(tensors)
return torch.cat(tensor_all, dim=0)
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
# if use distributed training
if not is_dist_avail_and_initialized():
return tensor
tensors_gather = [
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
def tile(x, dim, n_tile):
init_dim = x.size(dim)
repeat_idx = [1] * x.dim()
repeat_idx[dim] = n_tile
x = x.repeat(*(repeat_idx))
order_index = torch.LongTensor(
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
)
return torch.index_select(x, dim, order_index.to(x.device))

View File

@ -1,221 +0,0 @@
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
import minigpt4.common.dist_utils as dist_utils
from minigpt4.common.dist_utils import download_cached_file
from minigpt4.common.utils import is_url
from minigpt4.common.logger import MetricLogger
from minigpt4.models.base_model import BaseModel
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
from minigpt4.models.eva_vit import create_eva_vit_g
from transformers import BertTokenizer
class Blip2Base(BaseModel):
@classmethod
def init_tokenizer(cls):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
return tokenizer
def maybe_autocast(self, dtype=torch.float16):
# if on cpu, don't use autocast
# if on gpu, use autocast with dtype if provided, otherwise use torch.float16
enable_autocast = self.device != torch.device("cpu")
if enable_autocast:
return torch.cuda.amp.autocast(dtype=dtype)
else:
return contextlib.nullcontext()
@classmethod
def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = cross_attention_freq
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel(config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
return Qformer, query_tokens
@classmethod
def init_vision_encoder(
cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
):
assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
visual_encoder = create_eva_vit_g(
img_size, drop_path_rate, use_grad_checkpoint, precision
)
ln_vision = LayerNorm(visual_encoder.num_features)
return visual_encoder, ln_vision
def load_from_pretrained(self, url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
checkpoint = torch.load(cached_file, map_location="cpu")
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location="cpu")
else:
raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"]
msg = self.load_state_dict(state_dict, strict=False)
# logging.info("Missing keys {}".format(msg.missing_keys))
logging.info("load checkpoint from %s" % url_or_filename)
return msg
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
def compute_sim_matrix(model, data_loader, **kwargs):
k_test = kwargs.pop("k_test")
metric_logger = MetricLogger(delimiter=" ")
header = "Evaluation:"
logging.info("Computing features for evaluation...")
start_time = time.time()
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_ids = []
text_embeds = []
text_atts = []
for i in range(0, num_text, text_bs):
text = texts[i : min(num_text, i + text_bs)]
text_input = model.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=35,
return_tensors="pt",
).to(model.device)
text_feat = model.forward_text(text_input)
text_embed = F.normalize(model.text_proj(text_feat))
text_embeds.append(text_embed)
text_ids.append(text_input.input_ids)
text_atts.append(text_input.attention_mask)
text_embeds = torch.cat(text_embeds, dim=0)
text_ids = torch.cat(text_ids, dim=0)
text_atts = torch.cat(text_atts, dim=0)
vit_feats = []
image_embeds = []
for samples in data_loader:
image = samples["image"]
image = image.to(model.device)
image_feat, vit_feat = model.forward_image(image)
image_embed = model.vision_proj(image_feat)
image_embed = F.normalize(image_embed, dim=-1)
vit_feats.append(vit_feat.cpu())
image_embeds.append(image_embed)
vit_feats = torch.cat(vit_feats, dim=0)
image_embeds = torch.cat(image_embeds, dim=0)
sims_matrix = []
for image_embed in image_embeds:
sim_q2t = image_embed @ text_embeds.t()
sim_i2t, _ = sim_q2t.max(0)
sims_matrix.append(sim_i2t)
sims_matrix = torch.stack(sims_matrix, dim=0)
score_matrix_i2t = torch.full(
(len(data_loader.dataset.image), len(texts)), -100.0
).to(model.device)
num_tasks = dist_utils.get_world_size()
rank = dist_utils.get_rank()
step = sims_matrix.size(0) // num_tasks + 1
start = rank * step
end = min(sims_matrix.size(0), start + step)
for i, sims in enumerate(
metric_logger.log_every(sims_matrix[start:end], 50, header)
):
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
score = model.compute_itm(
image_inputs=image_inputs,
text_ids=text_ids[topk_idx],
text_atts=text_atts[topk_idx],
).float()
score_matrix_i2t[start + i, topk_idx] = score + topk_sim
sims_matrix = sims_matrix.t()
score_matrix_t2i = torch.full(
(len(texts), len(data_loader.dataset.image)), -100.0
).to(model.device)
step = sims_matrix.size(0) // num_tasks + 1
start = rank * step
end = min(sims_matrix.size(0), start + step)
for i, sims in enumerate(
metric_logger.log_every(sims_matrix[start:end], 50, header)
):
topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
score = model.compute_itm(
image_inputs=image_inputs,
text_ids=text_ids[start + i].repeat(k_test, 1),
text_atts=text_atts[start + i].repeat(k_test, 1),
).float()
score_matrix_t2i[start + i, topk_idx] = score + topk_sim
if dist_utils.is_dist_avail_and_initialized():
dist.barrier()
torch.distributed.all_reduce(
score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
)
torch.distributed.all_reduce(
score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logging.info("Evaluation time {}".format(total_time_str))
return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()

View File

@ -1,110 +0,0 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from dataclasses import dataclass
from typing import Optional
import torch
from transformers.modeling_outputs import (
ModelOutput,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
@dataclass
class BlipSimilarity(ModelOutput):
sim_i2t: torch.FloatTensor = None
sim_t2i: torch.FloatTensor = None
sim_i2t_m: Optional[torch.FloatTensor] = None
sim_t2i_m: Optional[torch.FloatTensor] = None
sim_i2t_targets: Optional[torch.FloatTensor] = None
sim_t2i_targets: Optional[torch.FloatTensor] = None
@dataclass
class BlipIntermediateOutput(ModelOutput):
"""
Data class for intermediate outputs of BLIP models.
image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
decoder_labels (torch.LongTensor): labels for the captioning loss.
itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
"""
# uni-modal features
image_embeds: torch.FloatTensor = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds_m: Optional[torch.FloatTensor] = None
text_embeds_m: Optional[torch.FloatTensor] = None
# intermediate outputs of multimodal encoder
encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
itm_logits: Optional[torch.FloatTensor] = None
itm_labels: Optional[torch.LongTensor] = None
# intermediate outputs of multimodal decoder
decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
decoder_labels: Optional[torch.LongTensor] = None
@dataclass
class BlipOutput(ModelOutput):
# some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
sims: Optional[BlipSimilarity] = None
intermediate_output: BlipIntermediateOutput = None
loss: Optional[torch.FloatTensor] = None
loss_itc: Optional[torch.FloatTensor] = None
loss_itm: Optional[torch.FloatTensor] = None
loss_lm: Optional[torch.FloatTensor] = None
@dataclass
class BlipOutputFeatures(ModelOutput):
"""
Data class of features from BlipFeatureExtractor.
Args:
image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
The first embedding or feature is for the [CLS] token.
Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
"""
image_embeds: Optional[torch.FloatTensor] = None
image_embeds_proj: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
text_embeds_proj: Optional[torch.FloatTensor] = None
multimodal_embeds: Optional[torch.FloatTensor] = None

View File

@ -1,384 +0,0 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
"""
BLIP2 GPT-LLAMA model.
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
"pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
}
def __init__(
self,
vit_model="eva_clip_g",
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
has_qformer=True,
freeze_qformer=True,
num_query_token=32,
llama_model="",
prompt_path="",
prompt_template="",
max_txt_len=32,
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
lora_r=0,
lora_target_modules=["q_proj", "v_proj"],
lora_alpha=16,
lora_dropout=0.05,
):
super().__init__()
self.tokenizer = self.init_tokenizer()
self.low_resource = low_resource
print('Loading VIT')
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
)
if freeze_vit:
for name, param in self.visual_encoder.named_parameters():
param.requires_grad = False
self.visual_encoder = self.visual_encoder.eval()
self.visual_encoder.train = disabled_train
for name, param in self.ln_vision.named_parameters():
param.requires_grad = False
self.ln_vision = self.ln_vision.eval()
self.ln_vision.train = disabled_train
logging.info("freeze vision encoder")
print('Loading VIT Done')
self.has_qformer = has_qformer
if self.has_qformer:
print('Loading Q-Former')
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features
)
self.Qformer.cls = None
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.load_from_pretrained(url_or_filename=q_former_model)
if freeze_qformer:
for name, param in self.Qformer.named_parameters():
param.requires_grad = False
self.Qformer = self.Qformer.eval()
self.Qformer.train = disabled_train
self.query_tokens.requires_grad = False
logging.info("freeze Qformer")
img_f_dim = self.Qformer.config.hidden_size
print('Loading Q-Former Done')
else:
img_f_dim = self.visual_encoder.num_features * 4
print('Do not use Q-Former here.')
print('Loading LLAMA')
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
self.llama_tokenizer.pad_token = "$$"
if self.low_resource:
self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map={'': device_8bit}
)
else:
self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model,
torch_dtype=torch.float16,
)
if lora_r > 0:
self.llama_model = prepare_model_for_int8_training(self.llama_model)
loraconfig = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
self.llama_model = get_peft_model(self.llama_model, loraconfig)
# if ckpt_path:
# print('load the llm under lora')
# ckpt = torch.load(ckpt_path)
# set_peft_model_state_dict(self.llama_model,ckpt)
self.llama_model.print_trainable_parameters()
else:
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
print('Loading LLAMA Done')
self.llama_proj = nn.Linear(
img_f_dim, self.llama_model.config.hidden_size
)
self.max_txt_len = max_txt_len
self.end_sym = end_sym
if prompt_path:
with open(prompt_path, 'r') as f:
raw_prompts = f.read().splitlines()
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
print('Load {} training prompts'.format(len(self.prompt_list)))
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
else:
self.prompt_list = []
def vit_to_cpu(self):
self.ln_vision.to("cpu")
self.ln_vision.float()
self.visual_encoder.to("cpu")
self.visual_encoder.float()
def encode_img(self, image):
device = image.device
if self.low_resource:
self.vit_to_cpu()
image = image.to("cpu")
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
if self.has_qformer:
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
else:
image_embeds = image_embeds[:, 1:, :]
bs, pn, hs = image_embeds.shape
image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
inputs_llama = self.llama_proj(image_embeds)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
def get_context_emb(self, prompt, img_list):
device = img_list[0].device
prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
seg_tokens = [
self.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
def prompt_wrap(self, img_embeds, atts_img, prompts):
if prompts:
emb_lists = []
if isinstance(prompts, str):
prompts = [prompts] * len(img_embeds)
for each_img_embed, each_prompt in zip(img_embeds, prompts):
p_before, p_after = each_prompt.split('<ImageHere>')
p_before_tokens = self.llama_tokenizer(
p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_after_tokens = self.llama_tokenizer(
p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
emb_lists.append(wrapped_emb)
emb_lens = [emb.shape[1] for emb in emb_lists]
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
for i, emb in enumerate(emb_lists):
wrapped_embs[i, :emb_lens[i]] = emb
wrapped_atts[i, :emb_lens[i]] = 1
return wrapped_embs, wrapped_atts
else:
return img_embeds, atts_img
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
input_lens = []
cat_embs = []
cat_atts = []
for i in range(input_embs.size(0)):
input_len = input_atts[i].sum()
input_lens.append(input_len)
cat_embs.append(
torch.cat([
input_embs[i][:input_len],
output_embs[i],
input_embs[i][input_len:]
])
)
cat_atts.append(
torch.cat([
input_atts[i][:input_len],
output_atts[i],
input_atts[i][input_len:]
])
)
cat_embs = torch.stack(cat_embs)
cat_atts = torch.stack(cat_atts)
return cat_embs, cat_atts, input_lens
def forward(self, samples):
image = samples["image"]
img_embeds, atts_img = self.encode_img(image)
if self.prompt_list:
instruction = random.choice(self.prompt_list)
else:
instruction = samples["instruction_input"] if "instruction_input" in samples else None
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)
self.llama_tokenizer.padding_side = "right"
text = [t + self.end_sym for t in samples["answer"]]
to_regress_tokens = self.llama_tokenizer(
text,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False
).to(image.device)
batch_size = img_embeds.shape[0]
bos = torch.ones([batch_size, 1],
dtype=to_regress_tokens.input_ids.dtype,
device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
bos_embeds = self.embed_tokens(bos)
atts_bos = atts_img[:, :1]
to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
inputs_embeds, attention_mask, input_lens = \
self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([atts_bos, attention_mask], dim=1)
part_targets = to_regress_tokens.input_ids.masked_fill(
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
)
targets = (
torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
dtype=torch.long).to(image.device).fill_(-100)
)
for i, target in enumerate(part_targets):
targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target # plus 1 for bos
with self.maybe_autocast():
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
)
loss = outputs.loss
return {"loss": loss}
def embed_tokens(self, token_ids):
if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
else:
embeds = self.llama_model.base_model.embed_tokens(token_ids)
return embeds
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
llama_model = cfg.get("llama_model")
drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
has_qformer = cfg.get("has_qformer", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
device_8bit = cfg.get("device_8bit", 0)
prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
max_txt_len = cfg.get("max_txt_len", 32)
end_sym = cfg.get("end_sym", '\n')
lora_r = cfg.get("lora_r", 0)
lora_alpha = cfg.get("lora_alpha", 32)
model = cls(
vit_model=vit_model,
q_former_model=q_former_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
has_qformer=has_qformer,
freeze_qformer=freeze_qformer,
num_query_token=num_query_token,
llama_model=llama_model,
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
end_sym=end_sym,
low_resource=low_resource,
device_8bit=device_8bit,
lora_r=lora_r,
lora_alpha=lora_alpha,
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model

195
minigpt4/models/minigpt4.py Normal file
View File

@ -0,0 +1,195 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.base_model import disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
@registry.register_model("minigpt4")
class MiniGPT4(MiniGPTBase):
"""
MiniGPT-4 model
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
"pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
}
def __init__(
self,
vit_model="eva_clip_g",
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
has_qformer=True,
freeze_qformer=True,
num_query_token=32,
llama_model="",
prompt_path="",
prompt_template="",
max_txt_len=32,
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
):
super().__init__(
vit_model=vit_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
llama_model=llama_model,
max_txt_len=max_txt_len,
end_sym=end_sym,
low_resource=low_resource,
device_8bit=device_8bit,
)
self.has_qformer = has_qformer
if self.has_qformer:
print('Loading Q-Former')
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features, freeze_qformer
)
self.load_from_pretrained(url_or_filename=q_former_model) # load q-former weights here
img_f_dim = self.Qformer.config.hidden_size
print('Loading Q-Former Done')
else:
img_f_dim = self.visual_encoder.num_features * 4
print('Do not use Q-Former here.')
self.llama_proj = nn.Linear(
img_f_dim, self.llama_model.config.hidden_size
)
if prompt_path:
with open(prompt_path, 'r') as f:
raw_prompts = f.read().splitlines()
filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
print('Load {} training prompts'.format(len(self.prompt_list)))
print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
else:
self.prompt_list = []
@classmethod
def init_Qformer(cls, num_query_token, vision_width, freeze):
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
encoder_config.add_cross_attention = True
encoder_config.cross_attention_freq = 2
encoder_config.query_length = num_query_token
Qformer = BertLMHeadModel(config=encoder_config)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
Qformer.cls = None
Qformer.bert.embeddings.word_embeddings = None
Qformer.bert.embeddings.position_embeddings = None
for layer in Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
if freeze:
for name, param in Qformer.named_parameters():
param.requires_grad = False
Qformer = Qformer.eval()
Qformer.train = disabled_train
query_tokens.requires_grad = False
logging.info("freeze Qformer")
return Qformer, query_tokens
def encode_img(self, image):
device = image.device
if len(image.shape) > 4:
image = image.reshape(-1, *image.shape[-3:])
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
if self.has_qformer:
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
else:
image_embeds = image_embeds[:, 1:, :]
bs, pn, hs = image_embeds.shape
image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
inputs_llama = self.llama_proj(image_embeds)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
llama_model = cfg.get("llama_model")
drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
has_qformer = cfg.get("has_qformer", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
device_8bit = cfg.get("device_8bit", 0)
prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
max_txt_len = cfg.get("max_txt_len", 32)
end_sym = cfg.get("end_sym", '\n')
model = cls(
vit_model=vit_model,
q_former_model=q_former_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
has_qformer=has_qformer,
freeze_qformer=freeze_qformer,
num_query_token=num_query_token,
llama_model=llama_model,
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
end_sym=end_sym,
low_resource=low_resource,
device_8bit=device_8bit,
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model

View File

@ -0,0 +1,401 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
class MiniGPTBase(BaseModel):
"""
Base class for MiniGPT-4 and MiniGPT-v2
"""
def __init__(
self,
vit_model="eva_clip_g",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
llama_model="",
max_txt_len=32,
max_context_len=3800,
prompt_template="",
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
lora_r=0, # lora_r means lora is not used
lora_target_modules=["q_proj", "v_proj"],
lora_alpha=16,
lora_dropout=0.05,
):
super().__init__()
self.llama_model, self.llama_tokenizer = self.init_llm(
llama_model_path=llama_model,
low_resource=low_resource,
low_res_device=device_8bit,
lora_r=lora_r,
lora_target_modules=lora_target_modules,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
)
self.max_txt_len = max_txt_len
self.max_context_len = max_context_len
self.end_sym = end_sym
self.prompt_template = prompt_template
self.prompt_list = []
def vit_to_cpu(self):
self.ln_vision.to("cpu")
self.ln_vision.float()
self.visual_encoder.to("cpu")
self.visual_encoder.float()
def get_context_emb(self, prompt, img_list):
device = img_list[0].device
prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
seg_tokens = [
self.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids # only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
if prompts is None or len(prompts) == 0:
# prompts is not provided, just return the original image embedding
return img_embeds, atts_img
elif img_embeds is None:
# prompt is provided but there is no image embedding. return the prompt embedding in right padding
self.llama_tokenizer.padding_side = "right"
prompt_tokens = self.llama_tokenizer(
prompts,
return_tensors="pt",
padding="longest",
add_special_tokens=False
).to(self.device)
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
atts_prompt = prompt_tokens.attention_mask
return prompt_embeds, atts_prompt
else:
# return the multi-modal embedding in right padding
emb_lists = []
if isinstance(prompts, str):
prompts = [prompts] * len(img_embeds)
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
pn = each_img_embed.shape[-2]
if lengths is not None:
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
each_img_embed = each_img_embed[:lengths[idx] * pn]
p_segs = each_prompt.split('<ImageHere>')
interleave_emb = []
for idx, seg in enumerate(p_segs[:-1]):
p_tokens = self.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_embed = self.embed_tokens(p_tokens.input_ids)
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1))
wrapped_emb = torch.cat(interleave_emb, dim=1)
p_tokens = self.llama_tokenizer(
p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_embed = self.embed_tokens(p_tokens.input_ids)
wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1)
emb_lists.append(wrapped_emb)
emb_lens = [emb.shape[1] for emb in emb_lists]
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
for i, emb in enumerate(emb_lists):
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
wrapped_embs[i, :length] = emb[:, :length]
wrapped_atts[i, :length] = 1
return wrapped_embs, wrapped_atts
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
"""
Concatenate the batched input embedding and batched output embedding together.
Both the input and the output embedding should be right padded.
"""
input_lens = []
cat_embs = []
cat_atts = []
for i in range(input_embs.size(0)):
input_len = input_atts[i].sum()
input_lens.append(input_len)
cat_embs.append(
torch.cat([
input_embs[i][:input_len],
output_embs[i],
input_embs[i][input_len:]
])
)
cat_atts.append(
torch.cat([
input_atts[i][:input_len],
output_atts[i],
input_atts[i][input_len:]
])
)
cat_embs = torch.stack(cat_embs)
cat_atts = torch.stack(cat_atts)
return cat_embs, cat_atts, input_lens
def tokenize_conversation(self, conv_q, conv_a):
"""concatenate conversation and make sure the model is only trained to regress the answer"""
to_regress_token_ids_list = []
targets_list = []
batch_size = len(conv_q)
for batch_idx in range(batch_size):
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
questions = [self.llama_tokenizer(q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it
answers = [self.llama_tokenizer(q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in answers]
cur_id = []
cur_target = []
for i in range(len(questions)):
cur_id.append(answers[i].input_ids)
cur_target.append(answers[i].input_ids)
cur_id.append(questions[i].input_ids)
cur_target.append(torch.ones_like(questions[i].input_ids) * -100)
cur_id.append(answers[-1].input_ids)
cur_target.append(answers[-1].input_ids)
cur_id = torch.cat(cur_id, dim=1)
cur_target = torch.cat(cur_target, dim=1)
to_regress_token_ids_list.append(cur_id)
targets_list.append(cur_target)
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
to_regress_token_ids = torch.ones([batch_size, max_len],
dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id
targets = torch.ones([batch_size, max_len],
dtype=cur_id.dtype, device=self.device) * -100
for batch_idx in range(batch_size):
cur_len = to_regress_token_ids_list[batch_idx].shape[1]
to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len]
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int)
return to_regress_token_ids, to_regress_token_attn, targets
def preparing_embedding(self, samples):
### prepare input tokens
if 'image' in samples:
img_embeds, img_atts = self.encode_img(samples["image"])
else:
img_embeds = img_atts = None
if 'conv_q' in samples:
# handeling conversation datasets
conv_q, conv_a = samples['conv_q'], samples['conv_a']
connect_sym = samples['connect_sym'][0]
conv_q = [q.split(connect_sym)for q in conv_q]
conv_a = [a.split(connect_sym) for a in conv_a]
conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
else:
if "instruction_input" in samples:
instruction = samples["instruction_input"]
elif self.prompt_list:
instruction = random.choice(self.prompt_list)
else:
instruction = None
if self.chat_template:
instruction = [self.prompt_template.format(instruct) for instruct in instruction]
if 'length' in samples:
# the input is a image train (like videos)
bsz, pn, hs = img_embeds.shape
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
else:
cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
### prepare target tokens
self.llama_tokenizer.padding_side = "right"
text = [t + self.end_sym for t in samples["answer"]]
regress_tokens = self.llama_tokenizer(
text,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False
).to(self.device)
regress_token_ids = regress_tokens.input_ids
regress_atts = regress_tokens.attention_mask
part_targets = regress_token_ids.masked_fill(
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
)
regress_embeds = self.embed_tokens(regress_token_ids)
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
def forward(self, samples, reduction='mean'):
# prepare the embedding to condition and the embedding to regress
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
self.preparing_embedding(samples)
# concat the embedding to condition and the embedding to regress
inputs_embeds, attention_mask, input_lens = \
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
# get bos token embedding
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
bos_embeds = self.embed_tokens(bos)
bos_atts = cond_atts[:, :1]
# add bos token at the begining
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
# ensemble the final targets
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
dtype=torch.long).to(self.device).fill_(-100)
for i, target in enumerate(part_targets):
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
with self.maybe_autocast():
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
reduction=reduction
)
loss = outputs.loss
return {"loss": loss}
def embed_tokens(self, token_ids):
if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
else:
embeds = self.llama_model.base_model.embed_tokens(token_ids)
return embeds
@torch.no_grad()
def generate(
self,
images,
texts,
num_beams=1,
max_new_tokens=20,
min_length=1,
top_p=0.9,
repetition_penalty=1,
length_penalty=1,
temperature=1,
do_sample=False,
stop_words_ids=[2],
):
'''
function for generate test use
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
img_embeds, atts_img = self.encode_img(images.to(self.device))
image_lists = [[image_emb[None]] for image_emb in img_embeds]
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
with self.maybe_autocast():
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
temperature=temperature,
do_sample=do_sample,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty
# stopping_criteria=stopping_criteria,
)
answers = []
for output_token in outputs:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
return answers
@torch.no_grad()
def multi_select(self, images, texts, answers, num_cand=None):
all_losses = []
for answer in answers:
choice_samples = {
'image': images,
'instruction_input': texts,
'answer': answer
}
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
all_losses.append(loss)
torch.cuda.empty_cache()
all_losses = torch.cat(all_losses, dim=-1)
if num_cand is not None:
for i in range(all_losses.shape[0]):
all_losses[i, num_cand[i]:] = 9999
output_class_ranks = torch.argsort(all_losses, dim=-1)
return output_class_ranks.tolist()

View File

@ -0,0 +1,139 @@
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.base_model import disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
@registry.register_model("minigpt_v2")
class MiniGPTv2(MiniGPTBase):
"""
MiniGPT-v2 model
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain": "configs/models/minigpt_v2.yaml",
}
def __init__(
self,
vit_model="eva_clip_g",
img_size=448,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
llama_model="",
prompt_template='[INST] {} [/INST]',
max_txt_len=300,
end_sym='\n',
lora_r=64,
lora_target_modules=["q_proj", "v_proj"],
lora_alpha=16,
lora_dropout=0.05,
chat_template=False,
use_grad_checkpoint_llm=False,
max_context_len=3800,
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
):
super().__init__(
vit_model=vit_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
llama_model=llama_model,
max_txt_len=max_txt_len,
max_context_len=max_context_len,
end_sym=end_sym,
prompt_template=prompt_template,
low_resource=low_resource,
device_8bit=device_8bit,
lora_r=lora_r,
lora_target_modules=lora_target_modules,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
img_f_dim = self.visual_encoder.num_features * 4
self.llama_proj = nn.Linear(
img_f_dim, self.llama_model.config.hidden_size
)
self.chat_template = chat_template
if use_grad_checkpoint_llm:
self.llama_model.gradient_checkpointing_enable()
def encode_img(self, image):
device = image.device
if len(image.shape) > 4:
image = image.reshape(-1, *image.shape[-3:])
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_embeds = image_embeds[:, 1:, :]
bs, pn, hs = image_embeds.shape
image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
inputs_llama = self.llama_proj(image_embeds)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
img_size = cfg.get("image_size")
llama_model = cfg.get("llama_model")
drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
low_resource = cfg.get("low_resource", False)
prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
max_txt_len = cfg.get("max_txt_len", 300)
end_sym = cfg.get("end_sym", '\n')
lora_r = cfg.get("lora_r", 64)
lora_alpha = cfg.get("lora_alpha", 16)
chat_template = cfg.get("chat_template", False)
use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
max_context_len = cfg.get("max_context_len", 3800)
model = cls(
vit_model=vit_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
llama_model=llama_model,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
low_resource=low_resource,
end_sym=end_sym,
lora_r=lora_r,
lora_alpha=lora_alpha,
chat_template=chat_template,
use_grad_checkpoint_llm=use_grad_checkpoint_llm,
max_context_len=max_context_len,
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
if ckpt_path:
print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
ckpt = torch.load(ckpt_path, map_location="cpu")
msg = model.load_state_dict(ckpt['model'], strict=False)
return model

View File

@ -1,628 +1,17 @@
# This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
LLAMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlamaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
LLAMA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if query_embeds is not None:
inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
batch_size, seq_length, _ = inputs_embeds.shape
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaForCausalLM(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
class LlamaForCausalLM(LlamaForCausalLMOrig):
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@ -633,12 +22,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
query_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
reduction: Optional[str] = "mean",
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -657,13 +46,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -679,7 +68,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
query_embeds=query_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@ -687,7 +75,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
@ -695,12 +89,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss_fct = CrossEntropyLoss(reduction=reduction)
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if reduction == "none":
loss = loss.view(logits.size(0), -1).mean(1)
if not return_dict:
output = (logits,) + outputs[1:]
@ -713,43 +109,3 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
query_embeds = None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"query_embeds": query_embeds,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

View File

@ -1,5 +1,5 @@
model:
arch: mini_gpt4
arch: minigpt4
model_type: pretrain_llama2

View File

@ -1,5 +1,5 @@
model:
arch: mini_gpt4
arch: minigpt4
model_type: pretrain_llama2
max_txt_len: 160

View File

@ -1,5 +1,5 @@
model:
arch: mini_gpt4
arch: minigpt4
model_type: pretrain_vicuna0

View File

@ -1,5 +1,5 @@
model:
arch: mini_gpt4
arch: minigpt4
model_type: pretrain_vicuna0
max_txt_len: 160