diff --git a/README.md b/README.md
index 0e98d20..28e4f22 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
# MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models
-[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
+[Deyao Zhu](https://tsutikgiau.github.io/)* , [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution
**King Abdullah University of Science and Technology**
@@ -7,7 +7,7 @@
## News
-We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB.
+We now provide a llama 2 version of MiniGPT-4
## Online Demo
@@ -52,49 +52,52 @@ conda activate minigpt4
```
-**2. Prepare the pretrained Vicuna weights**
+**2. Prepare the pretrained LLM weights**
-The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B.
-Please refer to our instruction [here](PrepareVicuna.md)
-to prepare the Vicuna weights.
-The final weights would be in a single folder in a structure similar to the following:
+Currently, we provide both Vicuna V0 and Llama 2 version of MiniGPT-4.
+Download the corresponding LLM weights from the following huggingface space via clone the repository using git-lfs.
+
+| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B |
+:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
+ [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat/tree/main)
-```
-vicuna_weights
-├── config.json
-├── generation_config.json
-├── pytorch_model.bin.index.json
-├── pytorch_model-00001-of-00003.bin
-...
-```
Then, set the path to the vicuna weight in the model config file
-[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
+[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18
+and/or the path to the llama2 weight in the model config file
+[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15.
**3. Prepare the pretrained MiniGPT-4 checkpoint**
Download the pretrained checkpoints according to the Vicuna model you prepare.
-| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B |
-:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
- [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing)
+| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B |
+:------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
+ [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing)
Then, set the path to the pretrained checkpoint in the evaluation config file
-in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11.
+in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version.
### Launching Demo Locally
-Try out our demo [demo.py](demo.py) on your local machine by running
+Try out our demo [demo.py](demo.py) for the vicuna version on your local machine by running
```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
```
-To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1.
-This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B.
+or for Llama 2 version by
+
+```
+python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0
+```
+
+
+To save GPU memory, LLMs loads as 8 bit by default, with a beam search width of 1.
+This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memory for 7B LLM.
For more powerful GPUs, you can run the model
in 16 bit by setting low_resource to False in the config file
[minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width.
diff --git a/demo.py b/demo.py
index b3659f1..483b56c 100644
--- a/demo.py
+++ b/demo.py
@@ -10,7 +10,7 @@ import gradio as gr
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION
+from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2
# imports modules for registration
from minigpt4.datasets.builders import *
@@ -50,6 +50,9 @@ def setup_seeds(config):
# Model Initialization
# ========================================
+conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0,
+ 'pretrain_llama2': CONV_VISION_LLama2}
+
print('Initializing Chat')
args = parse_args()
cfg = Config(args)
@@ -59,15 +62,19 @@ model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+CONV_VISION = conv_dict[model_config.model_type]
+
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
+
# ========================================
# Gradio Setting
# ========================================
+
def gradio_reset(chat_state, img_list):
if chat_state is not None:
chat_state.messages = []
@@ -75,6 +82,7 @@ def gradio_reset(chat_state, img_list):
img_list = []
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+
def upload_img(gr_img, text_input, chat_state):
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None
@@ -83,6 +91,7 @@ def upload_img(gr_img, text_input, chat_state):
llm_message = chat.upload_img(gr_img, chat_state, img_list)
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
+
def gradio_ask(user_message, chatbot, chat_state):
if len(user_message) == 0:
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
@@ -101,6 +110,7 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
+
title = """
Demo of MiniGPT-4
"""
description = """This is the demo of MiniGPT-4. Upload your images and start chatting!
"""
article = """


diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml
index f9e55a3..b7298bb 100644
--- a/eval_configs/minigpt4_eval.yaml
+++ b/eval_configs/minigpt4_eval.yaml
@@ -1,14 +1,11 @@
model:
arch: mini_gpt4
- model_type: pretrain_vicuna
- freeze_vit: True
- freeze_qformer: True
+ model_type: pretrain_vicuna0
max_txt_len: 160
end_sym: "###"
low_resource: True
- prompt_path: "prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
- ckpt: '/path/to/pretrained/ckpt/'
+ ckpt: '/home/zhud/ibex/pretrained_minigpt4.pth'
datasets:
diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml
new file mode 100644
index 0000000..d30d03a
--- /dev/null
+++ b/eval_configs/minigpt4_llama2_eval.yaml
@@ -0,0 +1,22 @@
+model:
+ arch: mini_gpt4
+ model_type: pretrain_llama2
+ max_txt_len: 160
+ end_sym: ""
+ low_resource: True
+ prompt_template: '[INST] {} [/INST] '
+ ckpt: '/home/zhud/c2090/zhud/project/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20230826182/checkpoint_4.pth'
+
+
+datasets:
+ cc_sbu_align:
+ vis_processor:
+ train:
+ name: "blip2_image_eval"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+
+run:
+ task: image_text_pretrain
diff --git a/minigpt4/common/dist_utils.py b/minigpt4/common/dist_utils.py
index 9280150..a6fc1b9 100644
--- a/minigpt4/common/dist_utils.py
+++ b/minigpt4/common/dist_utils.py
@@ -55,7 +55,10 @@ def is_main_process():
def init_distributed_mode(args):
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ if args.distributed is False:
+ print("Not using distributed mode")
+ return
+ elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
diff --git a/minigpt4/configs/models/minigpt4_llama2.yaml b/minigpt4/configs/models/minigpt4_llama2.yaml
new file mode 100644
index 0000000..c201bdc
--- /dev/null
+++ b/minigpt4/configs/models/minigpt4_llama2.yaml
@@ -0,0 +1,29 @@
+model:
+ arch: mini_gpt4
+
+ # vit encoder
+ image_size: 224
+ drop_path_rate: 0
+ use_grad_checkpoint: False
+ vit_precision: "fp16"
+ freeze_vit: True
+ has_qformer: False
+
+ # generation configs
+ prompt: ""
+
+ llama_model: "/path/to/llama2/weight"
+
+preprocess:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 224
+ eval:
+ name: "blip2_image_eval"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+ eval:
+ name: "blip_caption"
diff --git a/minigpt4/configs/models/minigpt4.yaml b/minigpt4/configs/models/minigpt4_vicuna0.yaml
similarity index 91%
rename from minigpt4/configs/models/minigpt4.yaml
rename to minigpt4/configs/models/minigpt4_vicuna0.yaml
index 87af448..34bd2ed 100644
--- a/minigpt4/configs/models/minigpt4.yaml
+++ b/minigpt4/configs/models/minigpt4_vicuna0.yaml
@@ -12,12 +12,11 @@ model:
# Q-Former
num_query_token: 32
- # Vicuna
- llama_model: "/path/to/vicuna/weights/"
-
# generation configs
prompt: ""
+ llama_model: "/path/to/vicuna/weight"
+
preprocess:
vis_processor:
train:
diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
index 676d89f..7678814 100644
--- a/minigpt4/conversation/conversation.py
+++ b/minigpt4/conversation/conversation.py
@@ -39,18 +39,18 @@ class Conversation:
ret = self.system + self.sep
for role, message in self.messages:
if message:
- ret += role + ": " + message + self.sep
+ ret += role + message + self.sep
else:
- ret += role + ":"
+ ret += role
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
- ret += role + ": " + message + seps[i % 2]
+ ret += role + message + seps[i % 2]
else:
- ret += role + ":"
+ ret += role
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
@@ -106,16 +106,26 @@ class StoppingCriteriaSub(StoppingCriteria):
return False
-CONV_VISION = Conversation(
+CONV_VISION_Vicuna0 = Conversation(
system="Give the following image:
ImageContent. "
"You will be able to see the image once I provide it to you. Please answer my questions.",
- roles=("Human", "Assistant"),
+ roles=("Human: ", "Assistant: "),
messages=[],
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
+CONV_VISION_LLama2 = Conversation(
+ system="Give the following image:
ImageContent. "
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
+ roles=("[INST] ", " [/INST] "),
+ messages=[],
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="",
+)
+
class Chat:
diff --git a/minigpt4/datasets/datasets/cc_sbu_dataset.py b/minigpt4/datasets/datasets/cc_sbu_dataset.py
index f42bbce..80b658d 100644
--- a/minigpt4/datasets/datasets/cc_sbu_dataset.py
+++ b/minigpt4/datasets/datasets/cc_sbu_dataset.py
@@ -22,7 +22,7 @@ class CCSBUDataset(BaseDataset):
def to_dict(self, sample):
return {
"image": sample[0],
- "text_input": self.text_processor(sample[1]["caption"]),
+ "answer": self.text_processor(sample[1]["caption"]),
}
@@ -42,6 +42,6 @@ class CCSBUAlignDataset(CaptionDataset):
return {
"image": image,
- "text_input": caption,
+ "answer": caption,
"image_id": self.img_ids[ann["image_id"]],
}
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/laion_dataset.py b/minigpt4/datasets/datasets/laion_dataset.py
index 1becbe4..6f3ce87 100644
--- a/minigpt4/datasets/datasets/laion_dataset.py
+++ b/minigpt4/datasets/datasets/laion_dataset.py
@@ -26,6 +26,6 @@ class LaionDataset(BaseDataset):
def to_dict(self, sample):
return {
"image": sample[0],
- "text_input": self.text_processor(sample[1]["caption"]),
+ "answer": self.text_processor(sample[1]["caption"]),
}
diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index 667edd5..faed3d5 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -7,9 +7,17 @@ import torch.nn as nn
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
-from minigpt4.models.modeling_llama import LlamaForCausalLM
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import LlamaTokenizer
+from peft import (
+ LoraConfig,
+ get_peft_model,
+ get_peft_model_state_dict,
+ prepare_model_for_int8_training,
+ set_peft_model_state_dict,
+)
+
@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
@@ -18,7 +26,8 @@ class MiniGPT4(Blip2Base):
"""
PRETRAINED_MODEL_CONFIG_DICT = {
- "pretrain_vicuna": "configs/models/minigpt4.yaml",
+ "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml",
+ "pretrain_llama2": "configs/models/minigpt4_llama2.yaml",
}
def __init__(
@@ -30,6 +39,7 @@ class MiniGPT4(Blip2Base):
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
+ has_qformer=True,
freeze_qformer=True,
num_query_token=32,
llama_model="",
@@ -39,6 +49,10 @@ class MiniGPT4(Blip2Base):
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
+ lora_r=0,
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.05,
):
super().__init__()
@@ -61,30 +75,37 @@ class MiniGPT4(Blip2Base):
logging.info("freeze vision encoder")
print('Loading VIT Done')
- print('Loading Q-Former')
- self.Qformer, self.query_tokens = self.init_Qformer(
- num_query_token, self.visual_encoder.num_features
- )
- self.Qformer.cls = None
- self.Qformer.bert.embeddings.word_embeddings = None
- self.Qformer.bert.embeddings.position_embeddings = None
- for layer in self.Qformer.bert.encoder.layer:
- layer.output = None
- layer.intermediate = None
- self.load_from_pretrained(url_or_filename=q_former_model)
+ self.has_qformer = has_qformer
+ if self.has_qformer:
+ print('Loading Q-Former')
+ self.Qformer, self.query_tokens = self.init_Qformer(
+ num_query_token, self.visual_encoder.num_features
+ )
+ self.Qformer.cls = None
+ self.Qformer.bert.embeddings.word_embeddings = None
+ self.Qformer.bert.embeddings.position_embeddings = None
+ for layer in self.Qformer.bert.encoder.layer:
+ layer.output = None
+ layer.intermediate = None
+ self.load_from_pretrained(url_or_filename=q_former_model)
- if freeze_qformer:
- for name, param in self.Qformer.named_parameters():
- param.requires_grad = False
- self.Qformer = self.Qformer.eval()
- self.Qformer.train = disabled_train
- self.query_tokens.requires_grad = False
- logging.info("freeze Qformer")
- print('Loading Q-Former Done')
+ if freeze_qformer:
+ for name, param in self.Qformer.named_parameters():
+ param.requires_grad = False
+ self.Qformer = self.Qformer.eval()
+ self.Qformer.train = disabled_train
+ self.query_tokens.requires_grad = False
+ logging.info("freeze Qformer")
+
+ img_f_dim = self.Qformer.config.hidden_size
+ print('Loading Q-Former Done')
+ else:
+ img_f_dim = self.visual_encoder.num_features * 4
+ print('Do not use Q-Former here.')
print('Loading LLAMA')
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
- self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
+ self.llama_tokenizer.pad_token = "$$"
if self.low_resource:
self.llama_model = LlamaForCausalLM.from_pretrained(
@@ -99,12 +120,31 @@ class MiniGPT4(Blip2Base):
torch_dtype=torch.float16,
)
- for name, param in self.llama_model.named_parameters():
- param.requires_grad = False
+ if lora_r > 0:
+ self.llama_model = prepare_model_for_int8_training(self.llama_model)
+ loraconfig = LoraConfig(
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ target_modules=lora_target_modules,
+ lora_dropout=lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM"
+ )
+ self.llama_model = get_peft_model(self.llama_model, loraconfig)
+
+ # if ckpt_path:
+ # print('load the llm under lora')
+ # ckpt = torch.load(ckpt_path)
+ # set_peft_model_state_dict(self.llama_model,ckpt)
+ self.llama_model.print_trainable_parameters()
+
+ else:
+ for name, param in self.llama_model.named_parameters():
+ param.requires_grad = False
print('Loading LLAMA Done')
self.llama_proj = nn.Linear(
- self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
+ img_f_dim, self.llama_model.config.hidden_size
)
self.max_txt_len = max_txt_len
self.end_sym = end_sym
@@ -133,50 +173,109 @@ class MiniGPT4(Blip2Base):
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
- image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+ if self.has_qformer:
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
- query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
- query_output = self.Qformer.bert(
- query_embeds=query_tokens,
- encoder_hidden_states=image_embeds,
- encoder_attention_mask=image_atts,
- return_dict=True,
- )
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+ query_output = self.Qformer.bert(
+ query_embeds=query_tokens,
+ encoder_hidden_states=image_embeds,
+ encoder_attention_mask=image_atts,
+ return_dict=True,
+ )
- inputs_llama = self.llama_proj(query_output.last_hidden_state)
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
+ else:
+ image_embeds = image_embeds[:, 1:, :]
+ bs, pn, hs = image_embeds.shape
+ image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))
+
+ inputs_llama = self.llama_proj(image_embeds)
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama, atts_llama
- def prompt_wrap(self, img_embeds, atts_img, prompt):
- if prompt:
- batch_size = img_embeds.shape[0]
- p_before, p_after = prompt.split('')
- p_before_tokens = self.llama_tokenizer(
- p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
- p_after_tokens = self.llama_tokenizer(
- p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
- p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
- p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
- wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
- wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
- return wrapped_img_embeds, wrapped_atts_img
+ def get_context_emb(self, prompt, img_list):
+ device = img_list[0].device
+ prompt_segs = prompt.split('')
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
+ seg_tokens = [
+ self.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids
+ # only add bos to the first seg
+ for i, seg in enumerate(prompt_segs)
+ ]
+ seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
+
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
+
+ def prompt_wrap(self, img_embeds, atts_img, prompts):
+ if prompts:
+ emb_lists = []
+ if isinstance(prompts, str):
+ prompts = [prompts] * len(img_embeds)
+
+ for each_img_embed, each_prompt in zip(img_embeds, prompts):
+ p_before, p_after = each_prompt.split('')
+
+ p_before_tokens = self.llama_tokenizer(
+ p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_after_tokens = self.llama_tokenizer(
+ p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
+ p_before_embed = self.embed_tokens(p_before_tokens.input_ids)
+ p_after_embed = self.embed_tokens(p_after_tokens.input_ids)
+ wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1)
+ emb_lists.append(wrapped_emb)
+ emb_lens = [emb.shape[1] for emb in emb_lists]
+ pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
+ wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
+ wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device)
+ for i, emb in enumerate(emb_lists):
+ wrapped_embs[i, :emb_lens[i]] = emb
+ wrapped_atts[i, :emb_lens[i]] = 1
+ return wrapped_embs, wrapped_atts
else:
return img_embeds, atts_img
+ def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
+ input_lens = []
+ cat_embs = []
+ cat_atts = []
+ for i in range(input_embs.size(0)):
+ input_len = input_atts[i].sum()
+ input_lens.append(input_len)
+ cat_embs.append(
+ torch.cat([
+ input_embs[i][:input_len],
+ output_embs[i],
+ input_embs[i][input_len:]
+ ])
+ )
+ cat_atts.append(
+ torch.cat([
+ input_atts[i][:input_len],
+ output_atts[i],
+ input_atts[i][input_len:]
+ ])
+ )
+ cat_embs = torch.stack(cat_embs)
+ cat_atts = torch.stack(cat_atts)
+ return cat_embs, cat_atts, input_lens
+
def forward(self, samples):
image = samples["image"]
img_embeds, atts_img = self.encode_img(image)
- if hasattr(samples, 'question_split'): # VQA dataset
- print('VQA Batch')
- vqa_prompt = '###Human:
'
- img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
- elif self.prompt_list:
- prompt = random.choice(self.prompt_list)
- img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)
+
+ if self.prompt_list:
+ instruction = random.choice(self.prompt_list)
+ else:
+ instruction = samples["instruction_input"] if "instruction_input" in samples else None
+
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction)
self.llama_tokenizer.padding_side = "right"
-
- text = [t + self.end_sym for t in samples["text_input"]]
+ text = [t + self.end_sym for t in samples["answer"]]
to_regress_tokens = self.llama_tokenizer(
text,
@@ -187,26 +286,29 @@ class MiniGPT4(Blip2Base):
add_special_tokens=False
).to(image.device)
- targets = to_regress_tokens.input_ids.masked_fill(
- to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
- )
-
- empty_targets = (
- torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
- dtype=torch.long).to(image.device).fill_(-100) # plus one for bos
- )
- targets = torch.cat([empty_targets, targets], dim=1)
-
batch_size = img_embeds.shape[0]
bos = torch.ones([batch_size, 1],
dtype=to_regress_tokens.input_ids.dtype,
device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
- bos_embeds = self.llama_model.model.embed_tokens(bos)
+ bos_embeds = self.embed_tokens(bos)
atts_bos = atts_img[:, :1]
- to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
- inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
- attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
+ to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
+ inputs_embeds, attention_mask, input_lens = \
+ self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask)
+ inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
+ attention_mask = torch.cat([atts_bos, attention_mask], dim=1)
+
+ part_targets = to_regress_tokens.input_ids.masked_fill(
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
+ )
+ targets = (
+ torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
+ dtype=torch.long).to(image.device).fill_(-100)
+ )
+
+ for i, target in enumerate(part_targets):
+ targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target # plus 1 for bos
with self.maybe_autocast():
outputs = self.llama_model(
@@ -219,6 +321,13 @@ class MiniGPT4(Blip2Base):
return {"loss": loss}
+ def embed_tokens(self, token_ids):
+ if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model
+ embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
+ else:
+ embeds = self.llama_model.base_model.embed_tokens(token_ids)
+ return embeds
+
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
@@ -231,6 +340,7 @@ class MiniGPT4(Blip2Base):
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
+ has_qformer = cfg.get("has_qformer", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
device_8bit = cfg.get("device_8bit", 0)
@@ -240,6 +350,9 @@ class MiniGPT4(Blip2Base):
max_txt_len = cfg.get("max_txt_len", 32)
end_sym = cfg.get("end_sym", '\n')
+ lora_r = cfg.get("lora_r", 0)
+ lora_alpha = cfg.get("lora_alpha", 32)
+
model = cls(
vit_model=vit_model,
q_former_model=q_former_model,
@@ -248,6 +361,7 @@ class MiniGPT4(Blip2Base):
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
+ has_qformer=has_qformer,
freeze_qformer=freeze_qformer,
num_query_token=num_query_token,
llama_model=llama_model,
@@ -257,6 +371,8 @@ class MiniGPT4(Blip2Base):
end_sym=end_sym,
low_resource=low_resource,
device_8bit=device_8bit,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
)
ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
new file mode 100644
index 0000000..6920aab
--- /dev/null
+++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml
@@ -0,0 +1,55 @@
+model:
+ arch: mini_gpt4
+ model_type: pretrain_llama2
+
+
+datasets:
+ laion:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 115
+ cc_sbu:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+ sample_ratio: 14
+
+
+run:
+ task: image_text_pretrain
+ # optimizer
+ lr_sched: "linear_warmup_cosine_lr"
+ init_lr: 1e-4
+ min_lr: 8e-5
+ warmup_lr: 1e-6
+
+ weight_decay: 0.05
+ max_epoch: 4
+ batch_size_train: 64
+ batch_size_eval: 64
+ num_workers: 4
+ warmup_steps: 5000
+ iters_per_epoch: 5000
+
+ seed: 42
+ output_dir: "output/minigpt4_stage1_pretrain"
+
+ amp: True
+ resume_ckpt_path: null
+
+ evaluate: False
+ train_splits: ["train"]
+
+ device: "cuda"
+ world_size: 1
+ dist_url: "env://"
+ distributed: True
\ No newline at end of file
diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml
new file mode 100644
index 0000000..9a6ac2d
--- /dev/null
+++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml
@@ -0,0 +1,50 @@
+model:
+ arch: mini_gpt4
+ model_type: pretrain_llama2
+
+ max_txt_len: 160
+ end_sym: ""
+ prompt_path: "prompts/alignment.txt"
+ prompt_template: '[INST] {} [/INST] '
+ ckpt: '/path/to/stage1/checkpoint/'
+
+
+datasets:
+ cc_sbu_align:
+ vis_processor:
+ train:
+ name: "blip2_image_train"
+ image_size: 224
+ text_processor:
+ train:
+ name: "blip_caption"
+
+run:
+ task: image_text_pretrain
+ # optimizer
+ lr_sched: "linear_warmup_cosine_lr"
+ init_lr: 3e-5
+ min_lr: 1e-5
+ warmup_lr: 1e-6
+
+ weight_decay: 0.05
+ max_epoch: 5
+ iters_per_epoch: 200
+ batch_size_train: 12
+ batch_size_eval: 12
+ num_workers: 4
+ warmup_steps: 200
+
+ seed: 42
+ output_dir: "output/minigpt4_stage2_finetune"
+
+ amp: True
+ resume_ckpt_path: null
+
+ evaluate: False
+ train_splits: ["train"]
+
+ device: "cuda"
+ world_size: 1
+ dist_url: "env://"
+ distributed: True
\ No newline at end of file
diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml
index 044246c..4ec1597 100644
--- a/train_configs/minigpt4_stage1_pretrain.yaml
+++ b/train_configs/minigpt4_stage1_pretrain.yaml
@@ -1,8 +1,6 @@
model:
arch: mini_gpt4
- model_type: pretrain_vicuna
- freeze_vit: True
- freeze_qformer: True
+ model_type: pretrain_vicuna0
datasets:
diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml
index 1013bea..54cedb4 100644
--- a/train_configs/minigpt4_stage2_finetune.yaml
+++ b/train_configs/minigpt4_stage2_finetune.yaml
@@ -1,8 +1,7 @@
model:
arch: mini_gpt4
- model_type: pretrain_vicuna
- freeze_vit: True
- freeze_qformer: True
+ model_type: pretrain_vicuna0
+
max_txt_len: 160
end_sym: "###"
prompt_path: "prompts/alignment.txt"