Merge branch 'Vision-CAIR:main' into main
Commit 43fb425fef

README.md (49 lines changed)
@@ -1,5 +1,5 @@
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models

[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
[Deyao Zhu](https://tsutikgiau.github.io/)*, [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution

**King Abdullah University of Science and Technology**

@@ -9,7 +9,7 @@

## News
We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption can now be as low as 12GB.
We now provide a Llama 2 version of MiniGPT-4.

## Online Demo
@@ -54,49 +54,52 @@ conda activate minigpt4
```

**2. Prepare the pretrained Vicuna weights**
**2. Prepare the pretrained LLM weights**

The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B.
Please refer to our instruction [here](PrepareVicuna.md)
to prepare the Vicuna weights.
The final weights would be in a single folder in a structure similar to the following:
Currently, we provide both the Vicuna V0 and the Llama 2 versions of MiniGPT-4.
Download the corresponding LLM weights from the following Hugging Face spaces by cloning the repositories with git-lfs.

| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B |
:------------------:|:------------------:|:------------------:
[Download](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main)
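If git-lfs is inconvenient, the same weights can also be fetched programmatically; a minimal sketch, assuming the huggingface_hub package is installed and (for Llama 2) that the license has been accepted on Hugging Face — the repo id and local directory below are just examples:

```
from huggingface_hub import snapshot_download

# Example: fetch the Vicuna V0 7B weights into ./vicuna-7b (pick the repo you need).
snapshot_download(
    repo_id="Vision-CAIR/vicuna-7b",   # or "Vision-CAIR/vicuna" / "meta-llama/Llama-2-7b-chat-hf"
    local_dir="vicuna-7b",             # this path later goes into the config's llama_model field
)
```
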
```
vicuna_weights
├── config.json
├── generation_config.json
├── pytorch_model.bin.index.json
├── pytorch_model-00001-of-00003.bin
...
```

Then, set the path to the Vicuna weights in the model config file
[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18,
and/or the path to the Llama 2 weights in the model config file
[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
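A quick way to confirm these paths before launching anything is to read the configs back; a minimal sketch, assuming PyYAML is installed and the repo layout shown above:

```
import os
import yaml

# Hypothetical check: make sure llama_model in each config points at a real folder.
for cfg_path in ["minigpt4/configs/models/minigpt4_vicuna0.yaml",
                 "minigpt4/configs/models/minigpt4_llama2.yaml"]:
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    llama_path = cfg["model"]["llama_model"]
    print(cfg_path, "->", llama_path, "exists:", os.path.isdir(llama_path))
```
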
**3. Prepare the pretrained MiniGPT-4 checkpoint**

Download the pretrained checkpoints according to the Vicuna model you prepared.

| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
:------------------:|:------------------:
[Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing)
| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B |
:------------------:|:------------------:|:------------------:
[Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing)
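The Google Drive links above can also be fetched without a browser; a minimal sketch, assuming the gdown package (the output file name is just an example):

```
import gdown

# Example: download the checkpoint aligned with Vicuna 7B (swap the URL for the variant you need).
url = "https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing"
gdown.download(url, output="pretrained_minigpt4_7b.pth", fuzzy=True)
```
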
Then, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for the Vicuna version, or in [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for the Llama 2 version.


### Launching Demo Locally

Try out our demo [demo.py](demo.py) on your local machine by running
Try out our demo [demo.py](demo.py) for the Vicuna version on your local machine by running

```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```

To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1.
This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B.
or for the Llama 2 version by

```
python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
```

To save GPU memory, the LLM loads in 8-bit by default, with a beam search width of 1.
This configuration requires about 23G of GPU memory for the 13B LLM and 11.5G for the 7B LLM.
For more powerful GPUs, you can run the model
in 16-bit by setting low_resource to False in the config file
[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
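As a rough, optional sanity check of these memory figures (not part of the repo), the peak allocation can be printed from the same process once the model has loaded:

```
import torch

# Peak GPU memory allocated by PyTorch since process start, in GiB.
peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
print(f"Peak GPU memory: {peak_gib:.1f} GiB")
```
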
demo.py (12 lines changed)
@@ -10,7 +10,7 @@ import gradio as gr
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2

# imports modules for registration
from minigpt4.datasets.builders import *
@@ -50,6 +50,9 @@ def setup_seeds(config):
# Model Initialization
# ========================================

conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
             'pretrain_llama2': CONV_VISION_LLama2}

print('Initializing Chat')
args = parse_args()
cfg = Config(args)
@@ -59,15 +62,19 @@ model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

CONV_VISION = conv_dict[model_config.model_type]

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')


# ========================================
# Gradio Setting
# ========================================


def gradio_reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
@@ -75,6 +82,7 @@ def gradio_reset(chat_state, img_list):
        img_list = []
    return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list


def upload_img(gr_img, text_input, chat_state):
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
@@ -83,6 +91,7 @@ def upload_img(gr_img, text_input, chat_state):
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list


def gradio_ask(user_message, chatbot, chat_state):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
@@ -101,6 +110,7 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list


title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
@@ -1,14 +1,11 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  model_type: pretrain_vicuna0
  max_txt_len: 160
  end_sym: "###"
  low_resource: True
  prompt_path: "prompts/alignment.txt"
  prompt_template: '###Human: {} ###Assistant: '
  ckpt: '/path/to/pretrained/ckpt/'
  ckpt: '/path/to/checkpoint/'


datasets:
eval_configs/minigpt4_llama2_eval.yaml (new file, 22 lines)
@@ -0,0 +1,22 @@
model:
  arch: mini_gpt4
  model_type: pretrain_llama2
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_template: '[INST] {} [/INST] '
  ckpt: '/path/to/checkpoint/'


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
@@ -55,7 +55,10 @@ def is_main_process():


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    if args.distributed is False:
        print("Not using distributed mode")
        return
    elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
minigpt4/configs/models/minigpt4_llama2.yaml (new file, 29 lines)
@@ -0,0 +1,29 @@
model:
  arch: mini_gpt4

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  has_qformer: False

  # generation configs
  prompt: ""

  llama_model: "/path/to/llama2/weight"

preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 224
    eval:
      name: "blip2_image_eval"
      image_size: 224
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
@@ -12,12 +12,11 @@ model:
  # Q-Former
  num_query_token: 32

  # Vicuna
  llama_model: "/path/to/vicuna/weights/"

  # generation configs
  prompt: ""

  llama_model: "/path/to/vicuna/weight"

preprocess:
  vis_processor:
    train:
@@ -39,18 +39,18 @@ class Conversation:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                    ret += role + message + self.sep
                else:
                    ret += role + ":"
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                    ret += role + message + seps[i % 2]
                else:
                    ret += role + ":"
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")
@@ -106,16 +106,26 @@ class StoppingCriteriaSub(StoppingCriteria):
        return False


CONV_VISION = Conversation(
CONV_VISION_Vicuna0 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("Human", "Assistant"),
    roles=("Human: ", "Assistant: "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

CONV_VISION_LLama2 = Conversation(
    system="Give the following image: <Img>ImageContent</Img>. "
           "You will be able to see the image once I provide it to you. Please answer my questions.",
    roles=("<s>[INST] ", " [/INST] "),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="",
)

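For illustration (not part of the diff): because the separators now live inside the role strings, get_prompt simply concatenates role and message, so a single user turn renders roughly as follows under each template (the question text is a made-up example):

```
# Vicuna V0 template: sep="###", roles carry "Human: " / "Assistant: "
vicuna0_prompt = (
    "Give the following image: <Img>ImageContent</Img>. "
    "You will be able to see the image once I provide it to you. Please answer my questions."
    "###Human: <Img><ImageHere></Img> What is unusual about this image?"
    "###Assistant: "
)

# Llama 2 chat template: sep="", roles carry the [INST] markers themselves
llama2_prompt = (
    "Give the following image: <Img>ImageContent</Img>. "
    "You will be able to see the image once I provide it to you. Please answer my questions."
    "<s>[INST] <Img><ImageHere></Img> What is unusual about this image? [/INST] "
)
```
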
class Chat:
@@ -22,7 +22,7 @@ class CCSBUDataset(BaseDataset):
    def to_dict(self, sample):
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
            "answer": self.text_processor(sample[1]["caption"]),
        }


@@ -42,6 +42,6 @@ class CCSBUAlignDataset(CaptionDataset):

        return {
            "image": image,
            "text_input": caption,
            "answer": caption,
            "image_id": self.img_ids[ann["image_id"]],
        }
@@ -26,6 +26,6 @@ class LaionDataset(BaseDataset):
    def to_dict(self, sample):
        return {
            "image": sample[0],
            "text_input": self.text_processor(sample[1]["caption"]),
            "answer": self.text_processor(sample[1]["caption"]),
        }
@@ -24,7 +24,7 @@ class BaseModel(nn.Module):

    @property
    def device(self):
        return list(self.parameters())[0].device
        return list(self.parameters())[-1].device

    def load_checkpoint(self, url_or_filename):
        """
@@ -7,9 +7,17 @@ import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from minigpt4.models.modeling_llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)


@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
@@ -18,7 +26,8 @@ class MiniGPT4(Blip2Base):
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
        "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
        "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
    }

    def __init__(
@@ -30,6 +39,7 @@ class MiniGPT4(Blip2Base):
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        has_qformer=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
@@ -39,6 +49,10 @@ class MiniGPT4(Blip2Base):
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
        lora_r=0,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
    ):
        super().__init__()
@@ -61,30 +75,37 @@ class MiniGPT4(Blip2Base):
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)
        self.has_qformer = has_qformer
        if self.has_qformer:
            print('Loading Q-Former')
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token, self.visual_encoder.num_features
            )
            self.Qformer.cls = None
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.load_from_pretrained(url_or_filename=q_former_model)

        if freeze_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            logging.info("freeze Qformer")
        print('Loading Q-Former Done')
            if freeze_qformer:
                for name, param in self.Qformer.named_parameters():
                    param.requires_grad = False
                self.Qformer = self.Qformer.eval()
                self.Qformer.train = disabled_train
                self.query_tokens.requires_grad = False
                logging.info("freeze Qformer")

            img_f_dim = self.Qformer.config.hidden_size
            print('Loading Q-Former Done')
        else:
            img_f_dim = self.visual_encoder.num_features * 4
            print('Do not use Q-Former here.')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_tokenizer.pad_token = "$$"

        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
@@ -99,12 +120,31 @@ class MiniGPT4(Blip2Base):
                torch_dtype=torch.float16,
            )

        for name, param in self.llama_model.named_parameters():
            param.requires_grad = False
        if lora_r > 0:
            self.llama_model = prepare_model_for_int8_training(self.llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM"
            )
            self.llama_model = get_peft_model(self.llama_model, loraconfig)

            # if ckpt_path:
            #     print('load the llm under lora')
            #     ckpt = torch.load(ckpt_path)
            #     set_peft_model_state_dict(self.llama_model, ckpt)
            self.llama_model.print_trainable_parameters()

        else:
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
        print('Loading LLAMA Done')

        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
@@ -133,50 +173,109 @@

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
            if self.has_qformer:
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_output = self.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

            inputs_llama = self.llama_proj(query_output.last_hidden_state)
                inputs_llama = self.llama_proj(query_output.last_hidden_state)
            else:
                image_embeds = image_embeds[:, 1:, :]
                bs, pn, hs = image_embeds.shape
                image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))

                inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama
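For illustration (not part of the diff): a rough shape sketch of the new no-Q-Former branch, assuming the EVA ViT-g encoder used here, which emits one CLS token plus 256 patch tokens of width 1408 for a 224x224 input:

```
import torch

# Dummy ViT output: (batch, 1 CLS + 256 patches, 1408)
image_embeds = torch.randn(2, 257, 1408)
image_embeds = image_embeds[:, 1:, :]                   # drop the CLS token -> (2, 256, 1408)
bs, pn, hs = image_embeds.shape
image_embeds = image_embeds.view(bs, pn // 4, hs * 4)   # merge groups of 4 patches -> (2, 64, 5632)
# llama_proj is then a Linear(5632, llama_hidden_size), so each image costs 64 LLM tokens.
```
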
    def prompt_wrap(self, img_embeds, atts_img, prompt):
        if prompt:
            batch_size = img_embeds.shape[0]
            p_before, p_after = prompt.split('<ImageHere>')
            p_before_tokens = self.llama_tokenizer(
                p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_after_tokens = self.llama_tokenizer(
                p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
            p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
            wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
            wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
            return wrapped_img_embeds, wrapped_atts_img
    def get_context_emb(self, prompt, img_list):
        device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]

        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def prompt_wrap(self, img_embeds, atts_img, prompts):
        if prompts:
            emb_lists = []
            if isinstance(prompts, str):
                prompts = [prompts] * len(img_embeds)

            for each_img_embed, each_prompt in zip(img_embeds, prompts):
                p_before, p_after = each_prompt.split('<ImageHere>')

                p_before_tokens = self.llama_tokenizer(
                    p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_after_tokens = self.llama_tokenizer(
                    p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
                p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
                p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
                wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
                emb_lists.append(wrapped_emb)
            emb_lens = [emb.shape[1] for emb in emb_lists]
            pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
            wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
            wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
            for i, emb in enumerate(emb_lists):
                wrapped_embs[i, :emb_lens[i]] = emb
                wrapped_atts[i, :emb_lens[i]] = 1
            return wrapped_embs, wrapped_atts
        else:
            return img_embeds, atts_img

    def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
        input_lens = []
        cat_embs = []
        cat_atts = []
        for i in range(input_embs.size(0)):
            input_len = input_atts[i].sum()
            input_lens.append(input_len)
            cat_embs.append(
                torch.cat([
                    input_embs[i][:input_len],
                    output_embs[i],
                    input_embs[i][input_len:]
                ])
            )
            cat_atts.append(
                torch.cat([
                    input_atts[i][:input_len],
                    output_atts[i],
                    input_atts[i][input_len:]
                ])
            )
        cat_embs = torch.stack(cat_embs)
        cat_atts = torch.stack(cat_atts)
        return cat_embs, cat_atts, input_lens

    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        if hasattr(samples, 'question_split'):  # VQA dataset
            print('VQA Batch')
            vqa_prompt = '###Human: <Img><ImageHere></Img> '
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
        elif self.prompt_list:
            prompt = random.choice(self.prompt_list)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)

        if self.prompt_list:
            instruction = random.choice(self.prompt_list)
        else:
            instruction = samples["instruction_input"] if "instruction_input" in samples else None

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)

        self.llama_tokenizer.padding_side = "right"

        text = [t + self.end_sym for t in samples["text_input"]]
        text = [t + self.end_sym for t in samples["answer"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
@@ -187,26 +286,29 @@
            add_special_tokens=False
        ).to(image.device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                       dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds, attention_mask, input_lens = \
            self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
        inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, attention_mask], dim=1)

        part_targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        targets = (
            torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
                       dtype=torch.long).to(image.device).fill_(-100)
        )

        for i, target in enumerate(part_targets):
            targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # plus 1 for bos
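For illustration (not part of the diff): a toy version of how the new loss mask is laid out, with made-up lengths; each row is [bos][image + prompt][answer][padding], and only the answer positions keep real token ids:

```
import torch

input_lens = [5, 3]                          # per-sample [image + prompt] lengths
part_targets = [torch.tensor([11, 12]),      # answer token ids for sample 0
                torch.tensor([21, 22, 23])]  # answer token ids for sample 1
targets = torch.full((2, 10), -100)          # 10 = 1 bos + max prompt + max answer + pad
for i, target in enumerate(part_targets):
    targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target  # +1 skips the bos slot
print(targets[0].tolist())  # [-100, -100, -100, -100, -100, -100, 11, 12, -100, -100]
```
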
        with self.maybe_autocast():
            outputs = self.llama_model(
@@ -219,6 +321,13 @@

        return {"loss": loss}

    def embed_tokens(self, token_ids):
        if hasattr(self.llama_model.base_model, 'model'):  ## lora wrapped model
            embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
        else:
            embeds = self.llama_model.base_model.embed_tokens(token_ids)
        return embeds

    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
@@ -231,6 +340,7 @@
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        has_qformer = cfg.get("has_qformer", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        low_resource = cfg.get("low_resource", False)
        device_8bit = cfg.get("device_8bit", 0)
@@ -240,6 +350,9 @@
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 0)
        lora_alpha = cfg.get("lora_alpha", 32)

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
@@ -248,6 +361,7 @@
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            has_qformer=has_qformer,
            freeze_qformer=freeze_qformer,
            num_query_token=num_query_token,
            llama_model=llama_model,
@@ -257,6 +371,8 @@
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
train_configs/minigpt4_llama2_stage1_pretrain.yaml (new file, 55 lines)
@@ -0,0 +1,55 @@
model:
  arch: mini_gpt4
  model_type: pretrain_llama2


datasets:
  laion:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 115
  cc_sbu:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
    sample_ratio: 14


run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 1e-4
  min_lr: 8e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 4
  batch_size_train: 64
  batch_size_eval: 64
  num_workers: 4
  warmup_steps: 5000
  iters_per_epoch: 5000

  seed: 42
  output_dir: "output/minigpt4_stage1_pretrain"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
train_configs/minigpt4_llama2_stage2_finetune.yaml (new file, 50 lines)
@@ -0,0 +1,50 @@
model:
  arch: mini_gpt4
  model_type: pretrain_llama2

  max_txt_len: 160
  end_sym: "</s>"
  prompt_path: "prompts/alignment.txt"
  prompt_template: '[INST] {} [/INST] '
  ckpt: '/path/to/stage1/checkpoint/'


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 3e-5
  min_lr: 1e-5
  warmup_lr: 1e-6

  weight_decay: 0.05
  max_epoch: 5
  iters_per_epoch: 200
  batch_size_train: 12
  batch_size_eval: 12
  num_workers: 4
  warmup_steps: 200

  seed: 42
  output_dir: "output/minigpt4_stage2_finetune"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
@@ -1,8 +1,6 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  model_type: pretrain_vicuna0


datasets:
@@ -1,8 +1,7 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  model_type: pretrain_vicuna0

  max_txt_len: 160
  end_sym: "###"
  prompt_path: "prompts/alignment.txt"