Merge branch 'Vision-CAIR:main' into main

This commit is contained in:
Sypherd 2023-09-12 10:01:26 -06:00 committed by GitHub
commit 43fb425fef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 410 additions and 119 deletions

View File

@ -1,5 +1,5 @@
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
[Deyao Zhu](https://tsutikgiau.github.io/)* , [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
**King Abdullah University of Science and Technology**
@ -9,7 +9,7 @@
## News
We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB.
We now provide a llama 2 version of MiniGPT-4
## Online Demo
@ -54,49 +54,52 @@ conda activate minigpt4
```
**2. Prepare the pretrained Vicuna weights**
**2. Prepare the pretrained LLM weights**
The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B.
Please refer to our instruction [here](PrepareVicuna.md)
to prepare the Vicuna weights.
The final weights would be in a single folder in a structure similar to the following:
Currently, we provide both Vicuna V0 and Llama 2 version of MiniGPT-4.
Download the corresponding LLM weights from the following huggingface space via clone the repository using git-lfs.
| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B |
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main)
```
vicuna_weights
├── config.json
├── generation_config.json
├── pytorch_model.bin.index.json
├── pytorch_model-00001-of-00003.bin
...
```
Then, set the path to the vicuna weight in the model config file
[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18
and/or the path to the llama2 weight in the model config file
[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
**3. Prepare the pretrained MiniGPT-4 checkpoint**
Download the pretrained checkpoints according to the Vicuna model you prepare.
| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing)
| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B |
:------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing)
Then, set the path to the pretrained checkpoint in the evaluation config file
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version.
### Launching Demo Locally
Try out our demo [demo.py](demo.py) on your local machine by running
Try out our demo [demo.py](demo.py) for the vicuna version on your local machine by running
```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```
To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1.
This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B.
or for Llama 2 version by
```
python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
```
To save GPU memory, LLMs loads as 8 bit by default, with a beam search width of 1.
This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memory for 7B LLM.
For more powerful GPUs, you can run the model
in 16 bit by setting low_resource to False in the config file
[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.

12
demo.py
View File

@ -10,7 +10,7 @@ import gradio as gr
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
# imports modules for registration
from minigpt4.datasets.builders import *
@ -50,6 +50,9 @@ def setup_seeds(config):
# Model Initialization
# ========================================
conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
'pretrain_llama2': CONV_VISION_LLama2}
print('Initializing Chat')
args = parse_args()
cfg = Config(args)
@ -59,15 +62,19 @@ model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
CONV_VISION = conv_dict[model_config.model_type]
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
# ========================================
# Gradio Setting
# ========================================
def gradio_reset(chat_state, img_list):
if chat_state is not None:
chat_state.messages = []
@ -75,6 +82,7 @@ def gradio_reset(chat_state, img_list):
img_list = []
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
def upload_img(gr_img, text_input, chat_state):
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None
@ -83,6 +91,7 @@ def upload_img(gr_img, text_input, chat_state):
llm_message = chat.upload_img(gr_img, chat_state, img_list)
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
@ -101,6 +110,7 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>

View File

@ -1,14 +1,11 @@
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
model_type: pretrain_vicuna0
max_txt_len: 160
end_sym: "###"
low_resource: True
prompt_path: "prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: '/path/to/pretrained/ckpt/'
ckpt: '/path/to/checkpoint/'
datasets:

View File

@ -0,0 +1,22 @@
model:
arch: mini_gpt4
model_type: pretrain_llama2
max_txt_len: 160
end_sym: "</s>"
low_resource: True
prompt_template: '[INST] {} [/INST] '
ckpt: '/path/to/checkpoint/'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain

View File

@ -55,7 +55,10 @@ def is_main_process():
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
if args.distributed is False:
print("Not using distributed mode")
return
elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])

View File

@ -0,0 +1,29 @@
model:
arch: mini_gpt4
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
freeze_vit: True
has_qformer: False
# generation configs
prompt: ""
llama_model: "/path/to/llama2/weight"
preprocess:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"

View File

@ -12,12 +12,11 @@ model:
# Q-Former
num_query_token: 32
# Vicuna
llama_model: "/path/to/vicuna/weights/"
# generation configs
prompt: ""
llama_model: "/path/to/vicuna/weight"
preprocess:
vis_processor:
train:

View File

@ -39,18 +39,18 @@ class Conversation:
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ": " + message + self.sep
ret += role + message + self.sep
else:
ret += role + ":"
ret += role
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
ret += role + message + seps[i % 2]
else:
ret += role + ":"
ret += role
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
@ -106,16 +106,26 @@ class StoppingCriteriaSub(StoppingCriteria):
return False
CONV_VISION = Conversation(
CONV_VISION_Vicuna0 = Conversation(
system="Give the following image: <Img>ImageContent</Img>. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
roles=("Human", "Assistant"),
roles=("Human: ", "Assistant: "),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
CONV_VISION_LLama2 = Conversation(
system="Give the following image: <Img>ImageContent</Img>. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
roles=("<s>[INST] ", " [/INST] "),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="",
)
class Chat:

View File

@ -22,7 +22,7 @@ class CCSBUDataset(BaseDataset):
def to_dict(self, sample):
return {
"image": sample[0],
"text_input": self.text_processor(sample[1]["caption"]),
"answer": self.text_processor(sample[1]["caption"]),
}
@ -42,6 +42,6 @@ class CCSBUAlignDataset(CaptionDataset):
return {
"image": image,
"text_input": caption,
"answer": caption,
"image_id": self.img_ids[ann["image_id"]],
}

View File

@ -26,6 +26,6 @@ class LaionDataset(BaseDataset):
def to_dict(self, sample):
return {
"image": sample[0],
"text_input": self.text_processor(sample[1]["caption"]),
"answer": self.text_processor(sample[1]["caption"]),
}

View File

@ -24,7 +24,7 @@ class BaseModel(nn.Module):
@property
def device(self):
return list(self.parameters())[0].device
return list(self.parameters())[-1].device
def load_checkpoint(self, url_or_filename):
"""

View File

@ -7,9 +7,17 @@ import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
from minigpt4.models.modeling_llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
@ -18,7 +26,8 @@ class MiniGPT4(Blip2Base):
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"pretrain_vicuna": "configs/models/minigpt4.yaml",
"pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
"pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
}
def __init__(
@ -30,6 +39,7 @@ class MiniGPT4(Blip2Base):
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
has_qformer=True,
freeze_qformer=True,
num_query_token=32,
llama_model="",
@ -39,6 +49,10 @@ class MiniGPT4(Blip2Base):
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
lora_r=0,
lora_target_modules=["q_proj", "v_proj"],
lora_alpha=16,
lora_dropout=0.05,
):
super().__init__()
@ -61,30 +75,37 @@ class MiniGPT4(Blip2Base):
logging.info("freeze vision encoder")
print('Loading VIT Done')
print('Loading Q-Former')
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features
)
self.Qformer.cls = None
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.load_from_pretrained(url_or_filename=q_former_model)
self.has_qformer = has_qformer
if self.has_qformer:
print('Loading Q-Former')
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features
)
self.Qformer.cls = None
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.load_from_pretrained(url_or_filename=q_former_model)
if freeze_qformer:
for name, param in self.Qformer.named_parameters():
param.requires_grad = False
self.Qformer = self.Qformer.eval()
self.Qformer.train = disabled_train
self.query_tokens.requires_grad = False
logging.info("freeze Qformer")
print('Loading Q-Former Done')
if freeze_qformer:
for name, param in self.Qformer.named_parameters():
param.requires_grad = False
self.Qformer = self.Qformer.eval()
self.Qformer.train = disabled_train
self.query_tokens.requires_grad = False
logging.info("freeze Qformer")
img_f_dim = self.Qformer.config.hidden_size
print('Loading Q-Former Done')
else:
img_f_dim = self.visual_encoder.num_features * 4
print('Do not use Q-Former here.')
print('Loading LLAMA')
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
self.llama_tokenizer.pad_token = "$$"
if self.low_resource:
self.llama_model = LlamaForCausalLM.from_pretrained(
@ -99,12 +120,31 @@ class MiniGPT4(Blip2Base):
torch_dtype=torch.float16,
)
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
if lora_r > 0:
self.llama_model = prepare_model_for_int8_training(self.llama_model)
loraconfig = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
self.llama_model = get_peft_model(self.llama_model, loraconfig)
# if ckpt_path:
# print('load the llm under lora')
# ckpt = torch.load(ckpt_path)
# set_peft_model_state_dict(self.llama_model,ckpt)
self.llama_model.print_trainable_parameters()
else:
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
print('Loading LLAMA Done')
self.llama_proj = nn.Linear(
self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
img_f_dim, self.llama_model.config.hidden_size
)
self.max_txt_len = max_txt_len
self.end_sym = end_sym
@ -133,50 +173,109 @@ class MiniGPT4(Blip2Base):
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
if self.has_qformer:
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
else:
image_embeds = image_embeds[:, 1:, :]
bs, pn, hs = image_embeds.shape
image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
inputs_llama = self.llama_proj(image_embeds)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
def prompt_wrap(self, img_embeds, atts_img, prompt):
if prompt:
batch_size = img_embeds.shape[0]
p_before, p_after = prompt.split('<ImageHere>')
p_before_tokens = self.llama_tokenizer(
p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_after_tokens = self.llama_tokenizer(
p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
return wrapped_img_embeds, wrapped_atts_img
def get_context_emb(self, prompt, img_list):
device = img_list[0].device
prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
seg_tokens = [
self.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
def prompt_wrap(self, img_embeds, atts_img, prompts):
if prompts:
emb_lists = []
if isinstance(prompts, str):
prompts = [prompts] * len(img_embeds)
for each_img_embed, each_prompt in zip(img_embeds, prompts):
p_before, p_after = each_prompt.split('<ImageHere>')
p_before_tokens = self.llama_tokenizer(
p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_after_tokens = self.llama_tokenizer(
p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
emb_lists.append(wrapped_emb)
emb_lens = [emb.shape[1] for emb in emb_lists]
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
for i, emb in enumerate(emb_lists):
wrapped_embs[i, :emb_lens[i]] = emb
wrapped_atts[i, :emb_lens[i]] = 1
return wrapped_embs, wrapped_atts
else:
return img_embeds, atts_img
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
input_lens = []
cat_embs = []
cat_atts = []
for i in range(input_embs.size(0)):
input_len = input_atts[i].sum()
input_lens.append(input_len)
cat_embs.append(
torch.cat([
input_embs[i][:input_len],
output_embs[i],
input_embs[i][input_len:]
])
)
cat_atts.append(
torch.cat([
input_atts[i][:input_len],
output_atts[i],
input_atts[i][input_len:]
])
)
cat_embs = torch.stack(cat_embs)
cat_atts = torch.stack(cat_atts)
return cat_embs, cat_atts, input_lens
def forward(self, samples):
image = samples["image"]
img_embeds, atts_img = self.encode_img(image)
if hasattr(samples, 'question_split'): # VQA dataset
print('VQA Batch')
vqa_prompt = '###Human: <Img><ImageHere></Img> '
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
elif self.prompt_list:
prompt = random.choice(self.prompt_list)
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
if self.prompt_list:
instruction = random.choice(self.prompt_list)
else:
instruction = samples["instruction_input"] if "instruction_input" in samples else None
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)
self.llama_tokenizer.padding_side = "right"
text = [t + self.end_sym for t in samples["text_input"]]
text = [t + self.end_sym for t in samples["answer"]]
to_regress_tokens = self.llama_tokenizer(
text,
@ -187,26 +286,29 @@ class MiniGPT4(Blip2Base):
add_special_tokens=False
).to(image.device)
targets = to_regress_tokens.input_ids.masked_fill(
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
)
empty_targets = (
torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
dtype=torch.long).to(image.device).fill_(-100) # plus one for bos
)
targets = torch.cat([empty_targets, targets], dim=1)
batch_size = img_embeds.shape[0]
bos = torch.ones([batch_size, 1],
dtype=to_regress_tokens.input_ids.dtype,
device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
bos_embeds = self.llama_model.model.embed_tokens(bos)
bos_embeds = self.embed_tokens(bos)
atts_bos = atts_img[:, :1]
to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
inputs_embeds, attention_mask, input_lens = \
self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([atts_bos, attention_mask], dim=1)
part_targets = to_regress_tokens.input_ids.masked_fill(
to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
)
targets = (
torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
dtype=torch.long).to(image.device).fill_(-100)
)
for i, target in enumerate(part_targets):
targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target # plus 1 for bos
with self.maybe_autocast():
outputs = self.llama_model(
@ -219,6 +321,13 @@ class MiniGPT4(Blip2Base):
return {"loss": loss}
def embed_tokens(self, token_ids):
if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
else:
embeds = self.llama_model.base_model.embed_tokens(token_ids)
return embeds
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
@ -231,6 +340,7 @@ class MiniGPT4(Blip2Base):
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
has_qformer = cfg.get("has_qformer", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
device_8bit = cfg.get("device_8bit", 0)
@ -240,6 +350,9 @@ class MiniGPT4(Blip2Base):
max_txt_len = cfg.get("max_txt_len", 32)
end_sym = cfg.get("end_sym", '\n')
lora_r = cfg.get("lora_r", 0)
lora_alpha = cfg.get("lora_alpha", 32)
model = cls(
vit_model=vit_model,
q_former_model=q_former_model,
@ -248,6 +361,7 @@ class MiniGPT4(Blip2Base):
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
has_qformer=has_qformer,
freeze_qformer=freeze_qformer,
num_query_token=num_query_token,
llama_model=llama_model,
@ -257,6 +371,8 @@ class MiniGPT4(Blip2Base):
end_sym=end_sym,
low_resource=low_resource,
device_8bit=device_8bit,
lora_r=lora_r,
lora_alpha=lora_alpha,
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4

View File

@ -0,0 +1,55 @@
model:
arch: mini_gpt4
model_type: pretrain_llama2
datasets:
laion:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
text_processor:
train:
name: "blip_caption"
sample_ratio: 115
cc_sbu:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
text_processor:
train:
name: "blip_caption"
sample_ratio: 14
run:
task: image_text_pretrain
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 1e-4
min_lr: 8e-5
warmup_lr: 1e-6
weight_decay: 0.05
max_epoch: 4
batch_size_train: 64
batch_size_eval: 64
num_workers: 4
warmup_steps: 5000
iters_per_epoch: 5000
seed: 42
output_dir: "output/minigpt4_stage1_pretrain"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True

View File

@ -0,0 +1,50 @@
model:
arch: mini_gpt4
model_type: pretrain_llama2
max_txt_len: 160
end_sym: "</s>"
prompt_path: "prompts/alignment.txt"
prompt_template: '[INST] {} [/INST] '
ckpt: '/path/to/stage1/checkpoint/'
datasets:
cc_sbu_align:
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
text_processor:
train:
name: "blip_caption"
run:
task: image_text_pretrain
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 3e-5
min_lr: 1e-5
warmup_lr: 1e-6
weight_decay: 0.05
max_epoch: 5
iters_per_epoch: 200
batch_size_train: 12
batch_size_eval: 12
num_workers: 4
warmup_steps: 200
seed: 42
output_dir: "output/minigpt4_stage2_finetune"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True

View File

@ -1,8 +1,6 @@
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
model_type: pretrain_vicuna0
datasets:

View File

@ -1,8 +1,7 @@
model:
arch: mini_gpt4
model_type: pretrain_vicuna
freeze_vit: True
freeze_qformer: True
model_type: pretrain_vicuna0
max_txt_len: 160
end_sym: "###"
prompt_path: "prompts/alignment.txt"