include llama2

This commit is contained in:
Deyao Zhu 2023-08-28 21:26:00 +03:00
parent bbd7883d1c
commit fb8e2c656a
15 changed files with 409 additions and 118 deletions

View File

@ -1,5 +1,5 @@
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
-[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
+[Deyao Zhu](https://tsutikgiau.github.io/)*, [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
**King Abdullah University of Science and Technology**
@ -7,7 +7,7 @@
## News
-We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB.
+We now provide a Llama 2 version of MiniGPT-4.
## Online Demo
@ -52,49 +52,52 @@ conda activate minigpt4
```

-**2. Prepare the pretrained Vicuna weights**
+**2. Prepare the pretrained LLM weights**

-The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B.
-Please refer to our instruction [here](PrepareVicuna.md)
-to prepare the Vicuna weights.
-The final weights would be in a single folder in a structure similar to the following:
-```
-vicuna_weights
-├── config.json
-├── generation_config.json
-├── pytorch_model.bin.index.json
-├── pytorch_model-00001-of-00003.bin
-...
-```

+Currently, we provide both Vicuna V0 and Llama 2 versions of MiniGPT-4.
+Download the corresponding LLM weights from the following Hugging Face repositories by cloning them with git-lfs.
+
+| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B |
+|:---:|:---:|:---:|
+| [Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat/tree/main) |

Then, set the path to the vicuna weight in the model config file
-[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
+[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18,
+and/or the path to the llama2 weight in the model config file
+[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
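
If cloning with git-lfs is inconvenient, the same repositories can usually be fetched programmatically; the sketch below uses `huggingface_hub` (the target folders are placeholders, and the gated `meta-llama` repo requires logging in with `huggingface-cli login` first).

```
# Sketch only: programmatic alternative to the git-lfs clone described above.
# Repo IDs come from the table; the local folder names are made up.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Vision-CAIR/vicuna-7b",      # or "meta-llama/Llama-2-7b-chat" (gated)
    local_dir="weights/vicuna-7b",        # point llama_model in the model yaml here
)
```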
**3. Prepare the pretrained MiniGPT-4 checkpoint**

Download the pretrained checkpoints according to the Vicuna model you prepare.

-| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
-|:---:|:---:|
-| [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) |
+| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B |
+|:---:|:---:|:---:|
+| [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) |

Then, set the path to the pretrained checkpoint in the evaluation config file
-in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
+in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for the Vicuna version, or in [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for the Llama 2 version.
### Launching Demo Locally

-Try out our demo [demo.py](demo.py) on your local machine by running
+Try out our demo [demo.py](demo.py) for the Vicuna version on your local machine by running

```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```

+or for the Llama 2 version by
+```
+python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
+```

-To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1.
-This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B.
+To save GPU memory, the LLM is loaded in 8-bit by default, with a beam search width of 1.
+This configuration requires about 23G of GPU memory for the 13B LLM and 11.5G for the 7B LLM.

For more powerful GPUs, you can run the model
in 16 bit by setting low_resource to False in the config file
[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
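
For reference, here is a minimal sketch (not the repository's exact code) of what the low_resource flag corresponds to when the LLM is created; 8-bit loading relies on bitsandbytes being installed.

```
import torch
from transformers import LlamaForCausalLM

# Sketch: low_resource=True loads the LLM in 8-bit on a single GPU (device_8bit),
# otherwise it is loaded in fp16. Parameter names mirror the MiniGPT4 constructor.
def load_llm(llama_model_path, low_resource=True, device_8bit=0):
    if low_resource:
        return LlamaForCausalLM.from_pretrained(
            llama_model_path,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            device_map={"": device_8bit},
        )
    return LlamaForCausalLM.from_pretrained(llama_model_path, torch_dtype=torch.float16)
```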

demo.py
View File

@ -10,7 +10,7 @@ import gradio as gr
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION
+from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2

# imports modules for registration
from minigpt4.datasets.builders import *
@ -50,6 +50,9 @@ def setup_seeds(config):
# Model Initialization
# ========================================

+conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
+             'pretrain_llama2': CONV_VISION_LLama2}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)
@ -59,15 +62,19 @@ model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

+CONV_VISION = conv_dict[model_config.model_type]

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
# ========================================
# Gradio Setting
# ========================================


def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
@ -75,6 +82,7 @@ def gradio_reset(chat_state, img_list):
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list


def upload_img(gr_img, text_input, chat_state):
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
@ -83,6 +91,7 @@ def upload_img(gr_img, text_input, chat_state):
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list


def gradio_ask(user_message, chatbot, chat_state):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
@ -101,6 +110,7 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list


title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>

View File

@ -1,14 +1,11 @@
model:
  arch: mini_gpt4
-  model_type: pretrain_vicuna
+  model_type: pretrain_vicuna0
-  freeze_vit: True
-  freeze_qformer: True
  max_txt_len: 160
  end_sym: "###"
  low_resource: True
-  prompt_path: "prompts/alignment.txt"
  prompt_template: '###Human: {} ###Assistant: '
-  ckpt: '/path/to/pretrained/ckpt/'
+  ckpt: '/home/zhud/ibex/pretrained_minigpt4.pth'

datasets:

View File

@ -0,0 +1,22 @@
model:
  arch: mini_gpt4
  model_type: pretrain_llama2
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_template: '[INST] {} [/INST] '
  ckpt: '/home/zhud/c2090/zhud/project/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20230826182/checkpoint_4.pth'

datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain

View File

@ -55,7 +55,10 @@ def is_main_process():
def init_distributed_mode(args):
-    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+    if args.distributed is False:
+        print("Not using distributed mode")
+        return
+    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])

View File

@ -0,0 +1,29 @@
model:
  arch: mini_gpt4

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  has_qformer: False

  # generation configs
  prompt: ""

  llama_model: "/path/to/llama2/weight"

preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 224
    eval:
      name: "blip2_image_eval"
      image_size: 224
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
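
For context, this file is selected by the model_type key in the eval and train configs; a tiny sketch of that lookup (the mapping is the one added to the MiniGPT4 class later in this commit):

```
# Sketch: how a config's model_type resolves to one of the model yaml files.
PRETRAINED_MODEL_CONFIG_DICT = {
    "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
    "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
}

model_type = "pretrain_llama2"                    # e.g. from minigpt4_llama2_eval.yaml
print(PRETRAINED_MODEL_CONFIG_DICT[model_type])   # configs/models/minigpt4_llama2.yaml
```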

View File

@ -12,12 +12,11 @@ model:
  # Q-Former
  num_query_token: 32

-  # Vicuna
-  llama_model: "/path/to/vicuna/weights/"

  # generation configs
  prompt: ""

+  llama_model: "/path/to/vicuna/weight"

preprocess:
  vis_processor:
    train:

View File

@ -39,18 +39,18 @@ class Conversation:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
-                    ret += role + ": " + message + self.sep
+                    ret += role + message + self.sep
                else:
-                    ret += role + ":"
+                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
-                    ret += role + ": " + message + seps[i % 2]
+                    ret += role + message + seps[i % 2]
                else:
-                    ret += role + ":"
+                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")
@ -106,16 +106,26 @@ class StoppingCriteriaSub(StoppingCriteria):
        return False
-CONV_VISION = Conversation(
+CONV_VISION_Vicuna0 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
-    roles=("Human", "Assistant"),
+    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

+CONV_VISION_LLama2 = Conversation(
+    system="Give the following image: <Img>ImageContent</Img>. "
+           "You will be able to see the image once I provide it to you. Please answer my questions.",
+    roles=("<s>[INST] ", " [/INST] "),
+    messages=[],
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="",
+)
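
To make the difference between the two templates concrete, here is a small standalone illustration of what the SINGLE-style get_prompt above produces for one user turn (system string and user message are placeholders):

```
# Illustration only: re-implements the SINGLE-style serialization shown above.
def render(system, roles, sep, user_msg):
    ret = system + sep
    ret += roles[0] + user_msg + sep   # filled human turn
    ret += roles[1]                    # empty assistant turn awaiting the reply
    return ret

msg = "<Img><ImageHere></Img> What is in the photo?"
print(render("<SYS>", ("Human: ", "Assistant: "), "###", msg))
# <SYS>###Human: <Img><ImageHere></Img> What is in the photo?###Assistant:
print(render("<SYS>", ("<s>[INST] ", " [/INST] "), "", msg))
# <SYS><s>[INST] <Img><ImageHere></Img> What is in the photo? [/INST]
```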


class Chat:

View File

@ -22,7 +22,7 @@ class CCSBUDataset(BaseDataset):
    def to_dict(self, sample):
        return {
            "image": sample[0],
-            "text_input": self.text_processor(sample[1]["caption"]),
+            "answer": self.text_processor(sample[1]["caption"]),
        }
@ -42,6 +42,6 @@ class CCSBUAlignDataset(CaptionDataset):
        return {
            "image": image,
-            "text_input": caption,
+            "answer": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }

View File

@ -26,6 +26,6 @@ class LaionDataset(BaseDataset):
    def to_dict(self, sample):
        return {
            "image": sample[0],
-            "text_input": self.text_processor(sample[1]["caption"]),
+            "answer": self.text_processor(sample[1]["caption"]),
        }

View File

@ -7,9 +7,17 @@ import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
-from minigpt4.models.modeling_llama import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

+from peft import (
+    LoraConfig,
+    get_peft_model,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    set_peft_model_state_dict,
+)


@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
@ -18,7 +26,8 @@ class MiniGPT4(Blip2Base):
    """
    PRETRAINED_MODEL_CONFIG_DICT = {
-        "pretrain_vicuna": "configs/models/minigpt4.yaml",
+        "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
+        "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
    }

    def __init__(
@ -30,6 +39,7 @@ class MiniGPT4(Blip2Base):
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
+        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
@ -39,6 +49,10 @@ class MiniGPT4(Blip2Base):
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
+        lora_r=0,
+        lora_target_modules=["q_proj", "v_proj"],
+        lora_alpha=16,
+        lora_dropout=0.05,
    ):
        super().__init__()
@ -61,30 +75,37 @@ class MiniGPT4(Blip2Base):
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

-        print('Loading Q-Former')
-        self.Qformer, self.query_tokens = self.init_Qformer(
-            num_query_token, self.visual_encoder.num_features
-        )
-        self.Qformer.cls = None
-        self.Qformer.bert.embeddings.word_embeddings = None
-        self.Qformer.bert.embeddings.position_embeddings = None
-        for layer in self.Qformer.bert.encoder.layer:
-            layer.output = None
-            layer.intermediate = None
-        self.load_from_pretrained(url_or_filename=q_former_model)
-        if freeze_qformer:
-            for name, param in self.Qformer.named_parameters():
-                param.requires_grad = False
-            self.Qformer = self.Qformer.eval()
-            self.Qformer.train = disabled_train
-            self.query_tokens.requires_grad = False
-            logging.info("freeze Qformer")
-        print('Loading Q-Former Done')
+        self.has_qformer = has_qformer
+        if self.has_qformer:
+            print('Loading Q-Former')
+            self.Qformer, self.query_tokens = self.init_Qformer(
+                num_query_token, self.visual_encoder.num_features
+            )
+            self.Qformer.cls = None
+            self.Qformer.bert.embeddings.word_embeddings = None
+            self.Qformer.bert.embeddings.position_embeddings = None
+            for layer in self.Qformer.bert.encoder.layer:
+                layer.output = None
+                layer.intermediate = None
+            self.load_from_pretrained(url_or_filename=q_former_model)
+            if freeze_qformer:
+                for name, param in self.Qformer.named_parameters():
+                    param.requires_grad = False
+                self.Qformer = self.Qformer.eval()
+                self.Qformer.train = disabled_train
+                self.query_tokens.requires_grad = False
+                logging.info("freeze Qformer")
+            img_f_dim = self.Qformer.config.hidden_size
+            print('Loading Q-Former Done')
+        else:
+            img_f_dim = self.visual_encoder.num_features * 4
+            print('Do not use Q-Former here.')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
-        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+        self.llama_tokenizer.pad_token = "$$"

        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
@ -99,12 +120,31 @@ class MiniGPT4(Blip2Base):
                torch_dtype=torch.float16,
            )

-        for name, param in self.llama_model.named_parameters():
-            param.requires_grad = False
+        if lora_r > 0:
+            self.llama_model = prepare_model_for_int8_training(self.llama_model)
+            loraconfig = LoraConfig(
+                r=lora_r,
+                lora_alpha=lora_alpha,
+                target_modules=lora_target_modules,
+                lora_dropout=lora_dropout,
+                bias="none",
+                task_type="CAUSAL_LM"
+            )
+            self.llama_model = get_peft_model(self.llama_model, loraconfig)
+            # if ckpt_path:
+            #     print('load the llm under lora')
+            #     ckpt = torch.load(ckpt_path)
+            #     set_peft_model_state_dict(self.llama_model,ckpt)
+            self.llama_model.print_trainable_parameters()
+        else:
+            for name, param in self.llama_model.named_parameters():
+                param.requires_grad = False
        print('Loading LLAMA Done')

        self.llama_proj = nn.Linear(
-            self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
+            img_f_dim, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
@ -133,50 +173,109 @@ class MiniGPT4(Blip2Base):
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
-            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-            query_output = self.Qformer.bert(
-                query_embeds=query_tokens,
-                encoder_hidden_states=image_embeds,
-                encoder_attention_mask=image_atts,
-                return_dict=True,
-            )
-            inputs_llama = self.llama_proj(query_output.last_hidden_state)
+            if self.has_qformer:
+                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+                inputs_llama = self.llama_proj(query_output.last_hidden_state)
+            else:
+                image_embeds = image_embeds[:, 1:, :]
+                bs, pn, hs = image_embeds.shape
+                image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
+                inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama
-    def prompt_wrap(self, img_embeds, atts_img, prompt):
-        if prompt:
-            batch_size = img_embeds.shape[0]
-            p_before, p_after = prompt.split('<ImageHere>')
-            p_before_tokens = self.llama_tokenizer(
-                p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
-            p_after_tokens = self.llama_tokenizer(
-                p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
-            p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
-            p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
-            wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
-            wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
-            return wrapped_img_embeds, wrapped_atts_img
+    def get_context_emb(self, prompt, img_list):
+        device = img_list[0].device
+        prompt_segs = prompt.split('<ImageHere>')
+        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
+        seg_tokens = [
+            self.llama_tokenizer(
+                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
+            # only add bos to the first seg
+            for i, seg in enumerate(prompt_segs)
+        ]
+        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
+
+        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+        mixed_embs = torch.cat(mixed_embs, dim=1)
+        return mixed_embs
+
+    def prompt_wrap(self, img_embeds, atts_img, prompts):
+        if prompts:
+            emb_lists = []
+            if isinstance(prompts, str):
+                prompts = [prompts] * len(img_embeds)
+
+            for each_img_embed, each_prompt in zip(img_embeds, prompts):
+                p_before, p_after = each_prompt.split('<ImageHere>')
+
+                p_before_tokens = self.llama_tokenizer(
+                    p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+                p_after_tokens = self.llama_tokenizer(
+                    p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+                p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
+                p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
+                wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
+                emb_lists.append(wrapped_emb)
+            emb_lens = [emb.shape[1] for emb in emb_lists]
+            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
+
+            wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
+            wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
+            for i, emb in enumerate(emb_lists):
+                wrapped_embs[i, :emb_lens[i]] = emb
+                wrapped_atts[i, :emb_lens[i]] = 1
+            return wrapped_embs, wrapped_atts
        else:
            return img_embeds, atts_img
+    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
+        input_lens = []
+        cat_embs = []
+        cat_atts = []
+        for i in range(input_embs.size(0)):
+            input_len = input_atts[i].sum()
+            input_lens.append(input_len)
+            cat_embs.append(
+                torch.cat([
+                    input_embs[i][:input_len],
+                    output_embs[i],
+                    input_embs[i][input_len:]
+                ])
+            )
+            cat_atts.append(
+                torch.cat([
+                    input_atts[i][:input_len],
+                    output_atts[i],
+                    input_atts[i][input_len:]
+                ])
+            )
+        cat_embs = torch.stack(cat_embs)
+        cat_atts = torch.stack(cat_atts)
+        return cat_embs, cat_atts, input_lens
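
A note on what concat_emb_input_output does (toy lengths, not taken from the repo): with right-padded prompts, each sample's answer embeddings are spliced in directly after its real prompt tokens, so prompt, answer, and padding stay contiguous per row, and the returned input_lens are later used to place the language-modeling targets.

```
# Toy illustration of the splice for one sample i:
input_atts_i = [1, 1, 1, 0, 0]     # prompt mask -> input_len = 3
output_atts_i = [1, 1]             # answer mask
# cat_embs[i] = prompt_embs[:3] + answer_embs + prompt_embs[3:]
cat_atts_i = input_atts_i[:3] + output_atts_i + input_atts_i[3:]
print(cat_atts_i)                  # [1, 1, 1, 1, 1, 0, 0]
# forward() then fills targets with -100 everywhere except the answer span,
# shifted by input_len + 1 to account for the prepended BOS embedding.
```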
    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
-        if hasattr(samples, 'question_split'):  # VQA dataset
-            print('VQA Batch')
-            vqa_prompt = '###Human: <Img><ImageHere></Img> '
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
-        elif self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
+        if self.prompt_list:
+            instruction = random.choice(self.prompt_list)
+        else:
+            instruction = samples["instruction_input"] if "instruction_input" in samples else None
+
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)

        self.llama_tokenizer.padding_side = "right"

-        text = [t + self.end_sym for t in samples["text_input"]]
+        text = [t + self.end_sym for t in samples["answer"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
@ -187,26 +286,29 @@ class MiniGPT4(Blip2Base):
            add_special_tokens=False
        ).to(image.device)

-        targets = to_regress_tokens.input_ids.masked_fill(
-            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
-        )
-        empty_targets = (
-            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
-                       dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
-        )
-        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
-        bos_embeds = self.llama_model.model.embed_tokens(bos)
+        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

-        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
-        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
-        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
+        inputs_embeds, attention_mask, input_lens = \
+            self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
+        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
+        attention_mask = torch.cat([atts_bos, attention_mask], dim=1)
+
+        part_targets = to_regress_tokens.input_ids.masked_fill(
+            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
+        )
+        targets = (
+            torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
+                       dtype=torch.long).to(image.device).fill_(-100)
+        )
+
+        for i, target in enumerate(part_targets):
+            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos
        with self.maybe_autocast():
            outputs = self.llama_model(
@ -219,6 +321,13 @@ class MiniGPT4(Blip2Base):
        return {"loss": loss}
+    def embed_tokens(self, token_ids):
+        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
+            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
+        else:
+            embeds = self.llama_model.base_model.embed_tokens(token_ids)
+        return embeds
    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
@ -231,6 +340,7 @@ class MiniGPT4(Blip2Base):
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
+        has_qformer = cfg.get("has_qformer", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        low_resource = cfg.get("low_resource", False)
        device_8bit = cfg.get("device_8bit", 0)
@ -240,6 +350,9 @@ class MiniGPT4(Blip2Base):
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

+        lora_r = cfg.get("lora_r", 0)
+        lora_alpha = cfg.get("lora_alpha", 32)

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
@ -248,6 +361,7 @@ class MiniGPT4(Blip2Base):
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
+            has_qformer=has_qformer,
            freeze_qformer=freeze_qformer,
            num_query_token=num_query_token,
            llama_model=llama_model,
@ -257,6 +371,8 @@ class MiniGPT4(Blip2Base):
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
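
For intuition on the no-Q-Former path used by the Llama 2 variant (has_qformer: False), here is a standalone sketch of the tensor bookkeeping in encode_img; the dimensions are assumptions for EVA-CLIP ViT-g at 224x224 input and Llama-2-7B, not values read from this diff.

```
import torch
import torch.nn as nn

# Sketch: drop the CLS token, merge every 4 neighboring patch tokens along the
# channel dimension, then project straight into the LLM embedding space.
bs, pn, hs = 2, 256, 1408          # assumed: batch, patch tokens without CLS, ViT width
llama_hidden = 4096                # assumed Llama-2-7B hidden size

image_embeds = torch.randn(bs, pn, hs)               # ViT output minus CLS token
grouped = image_embeds.view(bs, pn // 4, hs * 4)     # (2, 64, 5632): 4x fewer visual tokens
llama_proj = nn.Linear(hs * 4, llama_hidden)         # img_f_dim -> LLM hidden size
inputs_llama = llama_proj(grouped)                   # (2, 64, 4096) fed to the LLM
print(inputs_llama.shape)
```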

View File

@ -0,0 +1,55 @@
model:
  arch: mini_gpt4
  model_type: pretrain_llama2


datasets:
  laion:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 115
  cc_sbu:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 14


run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 1e-4
  min_lr: 8e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 4
  batch_size_train: 64
  batch_size_eval: 64
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/minigpt4_stage1_pretrain"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True

View File

@ -0,0 +1,50 @@
model:
  arch: mini_gpt4
  model_type: pretrain_llama2

  max_txt_len: 160
  end_sym: "</s>"
  prompt_path: "prompts/alignment.txt"
  prompt_template: '[INST] {} [/INST] '
  ckpt: '/path/to/stage1/checkpoint/'


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 3e-5
  min_lr: 1e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  batch_size_train: 12
  batch_size_eval: 12
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True

View File

@ -1,8 +1,6 @@
model:
  arch: mini_gpt4
-  model_type: pretrain_vicuna
+  model_type: pretrain_vicuna0
-  freeze_vit: True
-  freeze_qformer: True

datasets:

View File

@ -1,8 +1,7 @@
model:
  arch: mini_gpt4
-  model_type: pretrain_vicuna
+  model_type: pretrain_vicuna0
-  freeze_vit: True
-  freeze_qformer: True
  max_txt_len: 160
  end_sym: "###"
  prompt_path: "prompts/alignment.txt"