From fb8e2c656a882de1472a99ab7b9331e61be93955 Mon Sep 17 00:00:00 2001 From: Deyao Zhu Date: Mon, 28 Aug 2023 21:26:00 +0300 Subject: [PATCH 1/3] include llama2 --- README.md | 49 ++-- demo.py | 12 +- eval_configs/minigpt4_eval.yaml | 7 +- eval_configs/minigpt4_llama2_eval.yaml | 22 ++ minigpt4/common/dist_utils.py | 5 +- minigpt4/configs/models/minigpt4_llama2.yaml | 29 ++ .../{minigpt4.yaml => minigpt4_vicuna0.yaml} | 5 +- minigpt4/conversation/conversation.py | 22 +- minigpt4/datasets/datasets/cc_sbu_dataset.py | 4 +- minigpt4/datasets/datasets/laion_dataset.py | 2 +- minigpt4/models/mini_gpt4.py | 256 +++++++++++++----- .../minigpt4_llama2_stage1_pretrain.yaml | 55 ++++ .../minigpt4_llama2_stage2_finetune.yaml | 50 ++++ train_configs/minigpt4_stage1_pretrain.yaml | 4 +- train_configs/minigpt4_stage2_finetune.yaml | 5 +- 15 files changed, 409 insertions(+), 118 deletions(-) create mode 100644 eval_configs/minigpt4_llama2_eval.yaml create mode 100644 minigpt4/configs/models/minigpt4_llama2.yaml rename minigpt4/configs/models/{minigpt4.yaml => minigpt4_vicuna0.yaml} (91%) create mode 100644 train_configs/minigpt4_llama2_stage1_pretrain.yaml create mode 100644 train_configs/minigpt4_llama2_stage2_finetune.yaml diff --git a/README.md b/README.md index 0e98d20..28e4f22 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # MiniGPT-4: Enhancing Vision-language Understanding with Advanced Large Language Models -[Deyao Zhu](https://tsutikgiau.github.io/)* (On Job Market!), [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution +[Deyao Zhu](https://tsutikgiau.github.io/)* , [Jun Chen](https://junchen14.github.io/)* (On Job Market!), [Xiaoqian Shen](https://xiaoqian-shen.github.io), [Xiang Li](https://xiangli.ac.cn), and [Mohamed Elhoseiny](https://www.mohamed-elhoseiny.com/). *Equal Contribution **King Abdullah University of Science and Technology** @@ -7,7 +7,7 @@ ## News -We now provide a pretrained MiniGPT-4 aligned with Vicuna-7B! The demo GPU memory consumption now can be as low as 12GB. +We now provide a llama 2 version of MiniGPT-4 ## Online Demo @@ -52,49 +52,52 @@ conda activate minigpt4 ``` -**2. Prepare the pretrained Vicuna weights** +**2. Prepare the pretrained LLM weights** -The current version of MiniGPT-4 is built on the v0 version of Vicuna-13B. -Please refer to our instruction [here](PrepareVicuna.md) -to prepare the Vicuna weights. -The final weights would be in a single folder in a structure similar to the following: +Currently, we provide both Vicuna V0 and Llama 2 version of MiniGPT-4. +Download the corresponding LLM weights from the following huggingface space via clone the repository using git-lfs. + +| Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B | +:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------: + [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat/tree/main) -``` -vicuna_weights -├── config.json -├── generation_config.json -├── pytorch_model.bin.index.json -├── pytorch_model-00001-of-00003.bin -... -``` Then, set the path to the vicuna weight in the model config file -[here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16. +[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18 +and/or the path to the llama2 weight in the model config file +[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15. **3. Prepare the pretrained MiniGPT-4 checkpoint** Download the pretrained checkpoints according to the Vicuna model you prepare. -| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | -:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------: - [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) +| Checkpoint Aligned with Vicuna 13B | Checkpoint Aligned with Vicuna 7B | Checkpoint Aligned with Llama 2 Chat 7B | +:------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------: + [Downlad](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) Then, set the path to the pretrained checkpoint in the evaluation config file -in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 11. +in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version. ### Launching Demo Locally -Try out our demo [demo.py](demo.py) on your local machine by running +Try out our demo [demo.py](demo.py) for the vicuna version on your local machine by running ``` python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0 ``` -To save GPU memory, Vicuna loads as 8 bit by default, with a beam search width of 1. -This configuration requires about 23G GPU memory for Vicuna 13B and 11.5G GPU memory for Vicuna 7B. +or for Llama 2 version by + +``` +python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0 +``` + + +To save GPU memory, LLMs loads as 8 bit by default, with a beam search width of 1. +This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memory for 7B LLM. For more powerful GPUs, you can run the model in 16 bit by setting low_resource to False in the config file [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml) and use a larger beam search width. diff --git a/demo.py b/demo.py index b3659f1..483b56c 100644 --- a/demo.py +++ b/demo.py @@ -10,7 +10,7 @@ import gradio as gr from minigpt4.common.config import Config from minigpt4.common.dist_utils import get_rank from minigpt4.common.registry import registry -from minigpt4.conversation.conversation import Chat, CONV_VISION +from minigpt4.conversation.conversation import Chat, CONV_VISION_Vicuna0, CONV_VISION_LLama2 # imports modules for registration from minigpt4.datasets.builders import * @@ -50,6 +50,9 @@ def setup_seeds(config): # Model Initialization # ======================================== +conv_dict = {'pretrain_vicuna0': CONV_VISION_Vicuna0, + 'pretrain_llama2': CONV_VISION_LLama2} + print('Initializing Chat') args = parse_args() cfg = Config(args) @@ -59,15 +62,19 @@ model_config.device_8bit = args.gpu_id model_cls = registry.get_model_class(model_config.arch) model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) +CONV_VISION = conv_dict[model_config.model_type] + vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) print('Initialization Finished') + # ======================================== # Gradio Setting # ======================================== + def gradio_reset(chat_state, img_list): if chat_state is not None: chat_state.messages = [] @@ -75,6 +82,7 @@ def gradio_reset(chat_state, img_list): img_list = [] return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list + def upload_img(gr_img, text_input, chat_state): if gr_img is None: return None, None, gr.update(interactive=True), chat_state, None @@ -83,6 +91,7 @@ def upload_img(gr_img, text_input, chat_state): llm_message = chat.upload_img(gr_img, chat_state, img_list) return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list + def gradio_ask(user_message, chatbot, chat_state): if len(user_message) == 0: return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state @@ -101,6 +110,7 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): chatbot[-1][1] = llm_message return chatbot, chat_state, img_list + title = """

Demo of MiniGPT-4

""" description = """

This is the demo of MiniGPT-4. Upload your images and start chatting!

""" article = """

diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml index f9e55a3..b7298bb 100644 --- a/eval_configs/minigpt4_eval.yaml +++ b/eval_configs/minigpt4_eval.yaml @@ -1,14 +1,11 @@ model: arch: mini_gpt4 - model_type: pretrain_vicuna - freeze_vit: True - freeze_qformer: True + model_type: pretrain_vicuna0 max_txt_len: 160 end_sym: "###" low_resource: True - prompt_path: "prompts/alignment.txt" prompt_template: '###Human: {} ###Assistant: ' - ckpt: '/path/to/pretrained/ckpt/' + ckpt: '/home/zhud/ibex/pretrained_minigpt4.pth' datasets: diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml new file mode 100644 index 0000000..d30d03a --- /dev/null +++ b/eval_configs/minigpt4_llama2_eval.yaml @@ -0,0 +1,22 @@ +model: + arch: mini_gpt4 + model_type: pretrain_llama2 + max_txt_len: 160 + end_sym: "" + low_resource: True + prompt_template: '[INST] {} [/INST] ' + ckpt: '/home/zhud/c2090/zhud/project/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20230826182/checkpoint_4.pth' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/minigpt4/common/dist_utils.py b/minigpt4/common/dist_utils.py index 9280150..a6fc1b9 100644 --- a/minigpt4/common/dist_utils.py +++ b/minigpt4/common/dist_utils.py @@ -55,7 +55,10 @@ def is_main_process(): def init_distributed_mode(args): - if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + if args.distributed is False: + print("Not using distributed mode") + return + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) diff --git a/minigpt4/configs/models/minigpt4_llama2.yaml b/minigpt4/configs/models/minigpt4_llama2.yaml new file mode 100644 index 0000000..c201bdc --- /dev/null +++ b/minigpt4/configs/models/minigpt4_llama2.yaml @@ -0,0 +1,29 @@ +model: + arch: mini_gpt4 + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + has_qformer: False + + # generation configs + prompt: "" + + llama_model: "/path/to/llama2/weight" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/minigpt4.yaml b/minigpt4/configs/models/minigpt4_vicuna0.yaml similarity index 91% rename from minigpt4/configs/models/minigpt4.yaml rename to minigpt4/configs/models/minigpt4_vicuna0.yaml index 87af448..34bd2ed 100644 --- a/minigpt4/configs/models/minigpt4.yaml +++ b/minigpt4/configs/models/minigpt4_vicuna0.yaml @@ -12,12 +12,11 @@ model: # Q-Former num_query_token: 32 - # Vicuna - llama_model: "/path/to/vicuna/weights/" - # generation configs prompt: "" + llama_model: "/path/to/vicuna/weight" + preprocess: vis_processor: train: diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py index 676d89f..7678814 100644 --- a/minigpt4/conversation/conversation.py +++ b/minigpt4/conversation/conversation.py @@ -39,18 +39,18 @@ class Conversation: ret = self.system + self.sep for role, message in self.messages: if message: - ret += role + ": " + message + self.sep + ret += role + message + self.sep else: - ret += role + ":" + ret += role return ret elif self.sep_style == SeparatorStyle.TWO: seps = [self.sep, self.sep2] ret = self.system + seps[0] for i, (role, message) in enumerate(self.messages): if message: - ret += role + ": " + message + seps[i % 2] + ret += role + message + seps[i % 2] else: - ret += role + ":" + ret += role return ret else: raise ValueError(f"Invalid style: {self.sep_style}") @@ -106,16 +106,26 @@ class StoppingCriteriaSub(StoppingCriteria): return False -CONV_VISION = Conversation( +CONV_VISION_Vicuna0 = Conversation( system="Give the following image: ImageContent. " "You will be able to see the image once I provide it to you. Please answer my questions.", - roles=("Human", "Assistant"), + roles=("Human: ", "Assistant: "), messages=[], offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", ) +CONV_VISION_LLama2 = Conversation( + system="Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions.", + roles=("[INST] ", " [/INST] "), + messages=[], + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="", +) + class Chat: diff --git a/minigpt4/datasets/datasets/cc_sbu_dataset.py b/minigpt4/datasets/datasets/cc_sbu_dataset.py index f42bbce..80b658d 100644 --- a/minigpt4/datasets/datasets/cc_sbu_dataset.py +++ b/minigpt4/datasets/datasets/cc_sbu_dataset.py @@ -22,7 +22,7 @@ class CCSBUDataset(BaseDataset): def to_dict(self, sample): return { "image": sample[0], - "text_input": self.text_processor(sample[1]["caption"]), + "answer": self.text_processor(sample[1]["caption"]), } @@ -42,6 +42,6 @@ class CCSBUAlignDataset(CaptionDataset): return { "image": image, - "text_input": caption, + "answer": caption, "image_id": self.img_ids[ann["image_id"]], } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/laion_dataset.py b/minigpt4/datasets/datasets/laion_dataset.py index 1becbe4..6f3ce87 100644 --- a/minigpt4/datasets/datasets/laion_dataset.py +++ b/minigpt4/datasets/datasets/laion_dataset.py @@ -26,6 +26,6 @@ class LaionDataset(BaseDataset): def to_dict(self, sample): return { "image": sample[0], - "text_input": self.text_processor(sample[1]["caption"]), + "answer": self.text_processor(sample[1]["caption"]), } diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py index 667edd5..faed3d5 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/mini_gpt4.py @@ -7,9 +7,17 @@ import torch.nn as nn from minigpt4.common.registry import registry from minigpt4.models.blip2 import Blip2Base, disabled_train -from minigpt4.models.modeling_llama import LlamaForCausalLM +from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers import LlamaTokenizer +from peft import ( + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + prepare_model_for_int8_training, + set_peft_model_state_dict, +) + @registry.register_model("mini_gpt4") class MiniGPT4(Blip2Base): @@ -18,7 +26,8 @@ class MiniGPT4(Blip2Base): """ PRETRAINED_MODEL_CONFIG_DICT = { - "pretrain_vicuna": "configs/models/minigpt4.yaml", + "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml", + "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", } def __init__( @@ -30,6 +39,7 @@ class MiniGPT4(Blip2Base): use_grad_checkpoint=False, vit_precision="fp16", freeze_vit=True, + has_qformer=True, freeze_qformer=True, num_query_token=32, llama_model="", @@ -39,6 +49,10 @@ class MiniGPT4(Blip2Base): end_sym='\n', low_resource=False, # use 8 bit and put vit in cpu device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + lora_r=0, + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, ): super().__init__() @@ -61,30 +75,37 @@ class MiniGPT4(Blip2Base): logging.info("freeze vision encoder") print('Loading VIT Done') - print('Loading Q-Former') - self.Qformer, self.query_tokens = self.init_Qformer( - num_query_token, self.visual_encoder.num_features - ) - self.Qformer.cls = None - self.Qformer.bert.embeddings.word_embeddings = None - self.Qformer.bert.embeddings.position_embeddings = None - for layer in self.Qformer.bert.encoder.layer: - layer.output = None - layer.intermediate = None - self.load_from_pretrained(url_or_filename=q_former_model) + self.has_qformer = has_qformer + if self.has_qformer: + print('Loading Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) - if freeze_qformer: - for name, param in self.Qformer.named_parameters(): - param.requires_grad = False - self.Qformer = self.Qformer.eval() - self.Qformer.train = disabled_train - self.query_tokens.requires_grad = False - logging.info("freeze Qformer") - print('Loading Q-Former Done') + if freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + self.query_tokens.requires_grad = False + logging.info("freeze Qformer") + + img_f_dim = self.Qformer.config.hidden_size + print('Loading Q-Former Done') + else: + img_f_dim = self.visual_encoder.num_features * 4 + print('Do not use Q-Former here.') print('Loading LLAMA') self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) - self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token + self.llama_tokenizer.pad_token = "$$" if self.low_resource: self.llama_model = LlamaForCausalLM.from_pretrained( @@ -99,12 +120,31 @@ class MiniGPT4(Blip2Base): torch_dtype=torch.float16, ) - for name, param in self.llama_model.named_parameters(): - param.requires_grad = False + if lora_r > 0: + self.llama_model = prepare_model_for_int8_training(self.llama_model) + loraconfig = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=lora_target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM" + ) + self.llama_model = get_peft_model(self.llama_model, loraconfig) + + # if ckpt_path: + # print('load the llm under lora') + # ckpt = torch.load(ckpt_path) + # set_peft_model_state_dict(self.llama_model,ckpt) + self.llama_model.print_trainable_parameters() + + else: + for name, param in self.llama_model.named_parameters(): + param.requires_grad = False print('Loading LLAMA Done') self.llama_proj = nn.Linear( - self.Qformer.config.hidden_size, self.llama_model.config.hidden_size + img_f_dim, self.llama_model.config.hidden_size ) self.max_txt_len = max_txt_len self.end_sym = end_sym @@ -133,50 +173,109 @@ class MiniGPT4(Blip2Base): with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) - image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + if self.has_qformer: + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) - query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_output = self.Qformer.bert( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_atts, - return_dict=True, - ) + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) - inputs_llama = self.llama_proj(query_output.last_hidden_state) + inputs_llama = self.llama_proj(query_output.last_hidden_state) + else: + image_embeds = image_embeds[:, 1:, :] + bs, pn, hs = image_embeds.shape + image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4)) + + inputs_llama = self.llama_proj(image_embeds) atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) return inputs_llama, atts_llama - def prompt_wrap(self, img_embeds, atts_img, prompt): - if prompt: - batch_size = img_embeds.shape[0] - p_before, p_after = prompt.split('') - p_before_tokens = self.llama_tokenizer( - p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - p_after_tokens = self.llama_tokenizer( - p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) - p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) - wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1) - wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1]) - return wrapped_img_embeds, wrapped_atts_img + def get_context_emb(self, prompt, img_list): + device = img_list[0].device + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." + seg_tokens = [ + self.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i == 0).to(device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] + + mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs + + def prompt_wrap(self, img_embeds, atts_img, prompts): + if prompts: + emb_lists = [] + if isinstance(prompts, str): + prompts = [prompts] * len(img_embeds) + + for each_img_embed, each_prompt in zip(img_embeds, prompts): + p_before, p_after = each_prompt.split('') + + p_before_tokens = self.llama_tokenizer( + p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_after_tokens = self.llama_tokenizer( + p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) + p_before_embed = self.embed_tokens(p_before_tokens.input_ids) + p_after_embed = self.embed_tokens(p_after_tokens.input_ids) + wrapped_emb = torch.cat([p_before_embed, each_img_embed[None], p_after_embed], dim=1) + emb_lists.append(wrapped_emb) + emb_lens = [emb.shape[1] for emb in emb_lists] + pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) + wrapped_embs = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone() + wrapped_atts = torch.zeros([len(emb_lens), max(emb_lens)], dtype=torch.int, device=img_embeds.device) + for i, emb in enumerate(emb_lists): + wrapped_embs[i, :emb_lens[i]] = emb + wrapped_atts[i, :emb_lens[i]] = 1 + return wrapped_embs, wrapped_atts else: return img_embeds, atts_img + def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): + input_lens = [] + cat_embs = [] + cat_atts = [] + for i in range(input_embs.size(0)): + input_len = input_atts[i].sum() + input_lens.append(input_len) + cat_embs.append( + torch.cat([ + input_embs[i][:input_len], + output_embs[i], + input_embs[i][input_len:] + ]) + ) + cat_atts.append( + torch.cat([ + input_atts[i][:input_len], + output_atts[i], + input_atts[i][input_len:] + ]) + ) + cat_embs = torch.stack(cat_embs) + cat_atts = torch.stack(cat_atts) + return cat_embs, cat_atts, input_lens + def forward(self, samples): image = samples["image"] img_embeds, atts_img = self.encode_img(image) - if hasattr(samples, 'question_split'): # VQA dataset - print('VQA Batch') - vqa_prompt = '###Human: ' - img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt) - elif self.prompt_list: - prompt = random.choice(self.prompt_list) - img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt) + + if self.prompt_list: + instruction = random.choice(self.prompt_list) + else: + instruction = samples["instruction_input"] if "instruction_input" in samples else None + + img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, instruction) self.llama_tokenizer.padding_side = "right" - - text = [t + self.end_sym for t in samples["text_input"]] + text = [t + self.end_sym for t in samples["answer"]] to_regress_tokens = self.llama_tokenizer( text, @@ -187,26 +286,29 @@ class MiniGPT4(Blip2Base): add_special_tokens=False ).to(image.device) - targets = to_regress_tokens.input_ids.masked_fill( - to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 - ) - - empty_targets = ( - torch.ones([atts_img.shape[0], atts_img.shape[1]+1], - dtype=torch.long).to(image.device).fill_(-100) # plus one for bos - ) - targets = torch.cat([empty_targets, targets], dim=1) - batch_size = img_embeds.shape[0] bos = torch.ones([batch_size, 1], dtype=to_regress_tokens.input_ids.dtype, device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id - bos_embeds = self.llama_model.model.embed_tokens(bos) + bos_embeds = self.embed_tokens(bos) atts_bos = atts_img[:, :1] - to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) - inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1) - attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1) + to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids) + inputs_embeds, attention_mask, input_lens = \ + self.concat_emb_input_output(img_embeds, atts_img, to_regress_embeds, to_regress_tokens.attention_mask) + inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) + attention_mask = torch.cat([atts_bos, attention_mask], dim=1) + + part_targets = to_regress_tokens.input_ids.masked_fill( + to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100 + ) + targets = ( + torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], + dtype=torch.long).to(image.device).fill_(-100) + ) + + for i, target in enumerate(part_targets): + targets[i, input_lens[i] + 1:input_lens[i] + len(target) + 1] = target # plus 1 for bos with self.maybe_autocast(): outputs = self.llama_model( @@ -219,6 +321,13 @@ class MiniGPT4(Blip2Base): return {"loss": loss} + def embed_tokens(self, token_ids): + if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model + embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) + else: + embeds = self.llama_model.base_model.embed_tokens(token_ids) + return embeds + @classmethod def from_config(cls, cfg): vit_model = cfg.get("vit_model", "eva_clip_g") @@ -231,6 +340,7 @@ class MiniGPT4(Blip2Base): use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) vit_precision = cfg.get("vit_precision", "fp16") freeze_vit = cfg.get("freeze_vit", True) + has_qformer = cfg.get("has_qformer", True) freeze_qformer = cfg.get("freeze_qformer", True) low_resource = cfg.get("low_resource", False) device_8bit = cfg.get("device_8bit", 0) @@ -240,6 +350,9 @@ class MiniGPT4(Blip2Base): max_txt_len = cfg.get("max_txt_len", 32) end_sym = cfg.get("end_sym", '\n') + lora_r = cfg.get("lora_r", 0) + lora_alpha = cfg.get("lora_alpha", 32) + model = cls( vit_model=vit_model, q_former_model=q_former_model, @@ -248,6 +361,7 @@ class MiniGPT4(Blip2Base): use_grad_checkpoint=use_grad_checkpoint, vit_precision=vit_precision, freeze_vit=freeze_vit, + has_qformer=has_qformer, freeze_qformer=freeze_qformer, num_query_token=num_query_token, llama_model=llama_model, @@ -257,6 +371,8 @@ class MiniGPT4(Blip2Base): end_sym=end_sym, low_resource=low_resource, device_8bit=device_8bit, + lora_r=lora_r, + lora_alpha=lora_alpha, ) ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/train_configs/minigpt4_llama2_stage1_pretrain.yaml new file mode 100644 index 0000000..6920aab --- /dev/null +++ b/train_configs/minigpt4_llama2_stage1_pretrain.yaml @@ -0,0 +1,55 @@ +model: + arch: mini_gpt4 + model_type: pretrain_llama2 + + +datasets: + laion: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 115 + cc_sbu: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 14 + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 4 + batch_size_train: 64 + batch_size_eval: 64 + num_workers: 4 + warmup_steps: 5000 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "output/minigpt4_stage1_pretrain" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/train_configs/minigpt4_llama2_stage2_finetune.yaml new file mode 100644 index 0000000..9a6ac2d --- /dev/null +++ b/train_configs/minigpt4_llama2_stage2_finetune.yaml @@ -0,0 +1,50 @@ +model: + arch: mini_gpt4 + model_type: pretrain_llama2 + + max_txt_len: 160 + end_sym: "" + prompt_path: "prompts/alignment.txt" + prompt_template: '[INST] {} [/INST] ' + ckpt: '/path/to/stage1/checkpoint/' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 3e-5 + min_lr: 1e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 5 + iters_per_epoch: 200 + batch_size_train: 12 + batch_size_eval: 12 + num_workers: 4 + warmup_steps: 200 + + seed: 42 + output_dir: "output/minigpt4_stage2_finetune" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/train_configs/minigpt4_stage1_pretrain.yaml index 044246c..4ec1597 100644 --- a/train_configs/minigpt4_stage1_pretrain.yaml +++ b/train_configs/minigpt4_stage1_pretrain.yaml @@ -1,8 +1,6 @@ model: arch: mini_gpt4 - model_type: pretrain_vicuna - freeze_vit: True - freeze_qformer: True + model_type: pretrain_vicuna0 datasets: diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/train_configs/minigpt4_stage2_finetune.yaml index 1013bea..54cedb4 100644 --- a/train_configs/minigpt4_stage2_finetune.yaml +++ b/train_configs/minigpt4_stage2_finetune.yaml @@ -1,8 +1,7 @@ model: arch: mini_gpt4 - model_type: pretrain_vicuna - freeze_vit: True - freeze_qformer: True + model_type: pretrain_vicuna0 + max_txt_len: 160 end_sym: "###" prompt_path: "prompts/alignment.txt" From 871918f758761bfd8f43f367dc34ff54fcad9e71 Mon Sep 17 00:00:00 2001 From: Deyao Zhu Date: Tue, 29 Aug 2023 16:12:22 +0300 Subject: [PATCH 2/3] fix the device bug with latest transformers --- eval_configs/minigpt4_eval.yaml | 2 +- eval_configs/minigpt4_llama2_eval.yaml | 2 +- minigpt4/models/base_model.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml index b7298bb..b653eb7 100644 --- a/eval_configs/minigpt4_eval.yaml +++ b/eval_configs/minigpt4_eval.yaml @@ -5,7 +5,7 @@ model: end_sym: "###" low_resource: True prompt_template: '###Human: {} ###Assistant: ' - ckpt: '/home/zhud/ibex/pretrained_minigpt4.pth' + ckpt: '/path/to/checkpoint/' datasets: diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml index d30d03a..eea99d3 100644 --- a/eval_configs/minigpt4_llama2_eval.yaml +++ b/eval_configs/minigpt4_llama2_eval.yaml @@ -5,7 +5,7 @@ model: end_sym: "" low_resource: True prompt_template: '[INST] {} [/INST] ' - ckpt: '/home/zhud/c2090/zhud/project/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20230826182/checkpoint_4.pth' + ckpt: '/path/to/checkpoint/' datasets: diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py index 1cd2226..2a13393 100644 --- a/minigpt4/models/base_model.py +++ b/minigpt4/models/base_model.py @@ -24,7 +24,7 @@ class BaseModel(nn.Module): @property def device(self): - return list(self.parameters())[0].device + return list(self.parameters())[-1].device def load_checkpoint(self, url_or_filename): """ From ea0e81a2aa2a2b8100a46ba43c9267ce4d62dd01 Mon Sep 17 00:00:00 2001 From: ZhuDeyao Date: Fri, 1 Sep 2023 22:22:43 +0300 Subject: [PATCH 3/3] Update README.md fix the wrong link of llama2 chat --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 84d9ecb..4b31c7a 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ Download the corresponding LLM weights from the following huggingface space via | Vicuna V0 13B | Vicuna V0 7B | Llama 2 Chat 7B | :------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------: - [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat/tree/main) + [Downlad](https://huggingface.co/Vision-CAIR/vicuna/tree/main) | [Download](https://huggingface.co/Vision-CAIR/vicuna-7b/tree/main) | [Download](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main) Then, set the path to the vicuna weight in the model config file