diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py
index fa21962..ae0a3be 100644
--- a/minigpt4/models/base_model.py
+++ b/minigpt4/models/base_model.py
@@ -5,19 +5,26 @@
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 
-import contextlib
-import logging
 import os
+import logging
+import contextlib
+from omegaconf import OmegaConf
 
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BertTokenizer
+from transformers import BertTokenizer, LlamaTokenizer
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_int8_training,
+)
+
 from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
 from minigpt4.common.utils import get_abs_path, is_url
-from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
 from minigpt4.models.eva_vit import create_eva_vit_g
-from omegaconf import OmegaConf
+
 
 class BaseModel(nn.Module):
@@ -121,12 +128,6 @@ class BaseModel(nn.Module):
         else:
             return tot
 
-    @classmethod
-    def init_tokenizer(cls):
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
-        return tokenizer
-
     def maybe_autocast(self, dtype=torch.float16):
         # if on cpu, don't use autocast
         # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
@@ -137,33 +138,75 @@ class BaseModel(nn.Module):
         else:
             return contextlib.nullcontext()
 
-    @classmethod
-    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
-        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
-        encoder_config.encoder_width = vision_width
-        # insert cross-attention layer every other block
-        encoder_config.add_cross_attention = True
-        encoder_config.cross_attention_freq = cross_attention_freq
-        encoder_config.query_length = num_query_token
-        Qformer = BertLMHeadModel(config=encoder_config)
-        query_tokens = nn.Parameter(
-            torch.zeros(1, num_query_token, encoder_config.hidden_size)
-        )
-        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
-        return Qformer, query_tokens
-
     @classmethod
     def init_vision_encoder(
-        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
    ):
+        logging.info('Loading VIT')
+
+        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
+        if not freeze:
+            precision = "fp32"  # fp16 is not stable for training; fall back to fp32
+
         visual_encoder = create_eva_vit_g(
             img_size, drop_path_rate, use_grad_checkpoint, precision
         )
 
         ln_vision = LayerNorm(visual_encoder.num_features)
+
+        if freeze:
+            for name, param in visual_encoder.named_parameters():
+                param.requires_grad = False
+            visual_encoder = visual_encoder.eval()
+            visual_encoder.train = disabled_train
+            for name, param in ln_vision.named_parameters():
+                param.requires_grad = False
+            ln_vision = ln_vision.eval()
+            ln_vision.train = disabled_train
+            logging.info("freeze vision encoder")
+
+        logging.info('Loading VIT Done')
         return visual_encoder, ln_vision
 
+    @classmethod
+    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
+                 **lora_kargs):
+        logging.info('Loading LLAMA')
+        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
+        llama_tokenizer.pad_token = "$$"
+
+        if low_resource:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map={'': low_res_device}
+            )
+        else:
+            llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model_path,
+                torch_dtype=torch.float16,
+            )
+
+        if lora_r > 0:
+            llama_model = prepare_model_for_int8_training(llama_model)
+            loraconfig = LoraConfig(
+                r=lora_r,
+                bias="none",
+                task_type="CAUSAL_LM",
+                **lora_kargs
+            )
+            llama_model = get_peft_model(llama_model, loraconfig)
+
+            llama_model.print_trainable_parameters()
+
+        else:
+            for name, param in llama_model.named_parameters():
+                param.requires_grad = False
+        logging.info('Loading LLAMA Done')
+        return llama_model, llama_tokenizer
+
+
     def load_from_pretrained(self, url_or_filename):
         if is_url(url_or_filename):
             cached_file = download_cached_file(
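Reviewer note: `init_llm` centralizes the LLaMA/LoRA setup that previously lived in `MiniGPTBase.__init__`; everything beyond the first four parameters is forwarded via `**lora_kargs` into `LoraConfig`. A minimal usage sketch, assuming the repo is installed and a local LLaMA checkpoint is available (the path below is a placeholder):

```python
from minigpt4.models.base_model import BaseModel

llama_model, llama_tokenizer = BaseModel.init_llm(
    llama_model_path="path/to/llama-2-7b-chat-hf",  # placeholder local checkpoint
    low_resource=True,      # 8-bit loading; requires bitsandbytes
    low_res_device=0,
    lora_r=64,              # lora_r > 0 enables LoRA instead of full freezing
    lora_target_modules=["q_proj", "v_proj"],
    lora_alpha=16,
    lora_dropout=0.05,
)
```

With `lora_r=0` (the default), the else-branch instead freezes every LLM parameter, which is the configuration the Q-Former-based MiniGPT-4 uses.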
@@ -185,136 +227,6 @@ class BaseModel(nn.Module):
 
         return msg
 
-
-class BaseEncoder(nn.Module):
-    """
-    Base class for primitive encoders, such as ViT, TimeSformer, etc.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward_features(self, samples, **kwargs):
-        raise NotImplementedError
-
-    @property
-    def device(self):
-        return list(self.parameters())[0].device
-
-
-class SharedQueueMixin:
-    @torch.no_grad()
-    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
-        # gather keys before updating queue
-        image_feats = concat_all_gather(image_feat)
-        text_feats = concat_all_gather(text_feat)
-
-        batch_size = image_feats.shape[0]
-
-        ptr = int(self.queue_ptr)
-        assert self.queue_size % batch_size == 0  # for simplicity
-
-        # replace the keys at ptr (dequeue and enqueue)
-        self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
-        self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
-
-        if idxs is not None:
-            idxs = concat_all_gather(idxs)
-            self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
-
-        ptr = (ptr + batch_size) % self.queue_size  # move pointer
-        self.queue_ptr[0] = ptr
-
-
-class MomentumDistilationMixin:
-    @torch.no_grad()
-    def copy_params(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data.copy_(param.data)  # initialize
-                param_m.requires_grad = False  # not update by gradient
-
-    @torch.no_grad()
-    def _momentum_update(self):
-        for model_pair in self.model_pairs:
-            for param, param_m in zip(
-                model_pair[0].parameters(), model_pair[1].parameters()
-            ):
-                param_m.data = param_m.data * self.momentum + param.data * (
-                    1.0 - self.momentum
-                )
-
-
-class GatherLayer(torch.autograd.Function):
-    """
-    Gather tensors from all workers with support for backward propagation:
-    This implementation does not cut the gradients as torch.distributed.all_gather does.
-    """
-
-    @staticmethod
-    def forward(ctx, x):
-        output = [
-            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
-        ]
-        torch.distributed.all_gather(output, x)
-        return tuple(output)
-
-    @staticmethod
-    def backward(ctx, *grads):
-        all_gradients = torch.stack(grads)
-        torch.distributed.all_reduce(all_gradients)
-        return all_gradients[torch.distributed.get_rank()]
-
-
-def all_gather_with_grad(tensors):
-    """
-    Performs all_gather operation on the provided tensors.
-    Graph remains connected for backward grad computation.
-    """
-    # Queue the gathered tensors
-    world_size = torch.distributed.get_world_size()
-    # There is no need for reduction in the single-proc case
-    if world_size == 1:
-        return tensors
-
-    # tensor_all = GatherLayer.apply(tensors)
-    tensor_all = GatherLayer.apply(tensors)
-
-    return torch.cat(tensor_all, dim=0)
-
-
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    # if use distributed training
-    if not is_dist_avail_and_initialized():
-        return tensor
-
-    tensors_gather = [
-        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-def tile(x, dim, n_tile):
-    init_dim = x.size(dim)
-    repeat_idx = [1] * x.dim()
-    repeat_idx[dim] = n_tile
-    x = x.repeat(*(repeat_idx))
-    order_index = torch.LongTensor(
-        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
-    )
-    return torch.index_select(x, dim, order_index.to(x.device))
-
-
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
     does not change anymore."""
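The BLIP-era queue/gather/tile utilities removed above were dead code on the MiniGPT-4 path; only `disabled_train` survives, because the freeze logic in `init_vision_encoder` and `init_Qformer` depends on it. A self-contained illustration of why reassigning `.train` pins a frozen module in eval mode (a toy module stands in for the ViT):

```python
import torch.nn as nn

def disabled_train(self, mode=True):
    """Same no-op that base_model.py assigns to frozen modules."""
    return self

encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))  # toy stand-in for the ViT
for param in encoder.parameters():
    param.requires_grad = False  # no gradients accumulate in the frozen weights
encoder = encoder.eval()         # dropout switched off; encoder.training == False
encoder.train = disabled_train   # instance attribute shadows nn.Module.train

# A parent module's .train(mode) invokes child.train(mode); the positional
# `mode` lands in disabled_train's `self` slot and nothing is toggled.
encoder.train(True)
print(encoder.training)  # still False: the module is pinned in eval mode
```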
- """ - # Queue the gathered tensors - world_size = torch.distributed.get_world_size() - # There is no need for reduction in the single-proc case - if world_size == 1: - return tensors - - # tensor_all = GatherLayer.apply(tensors) - tensor_all = GatherLayer.apply(tensors) - - return torch.cat(tensor_all, dim=0) - - -@torch.no_grad() -def concat_all_gather(tensor): - """ - Performs all_gather operation on the provided tensors. - *** Warning ***: torch.distributed.all_gather has no gradient. - """ - # if use distributed training - if not is_dist_avail_and_initialized(): - return tensor - - tensors_gather = [ - torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) - ] - torch.distributed.all_gather(tensors_gather, tensor, async_op=False) - - output = torch.cat(tensors_gather, dim=0) - return output - - -def tile(x, dim, n_tile): - init_dim = x.size(dim) - repeat_idx = [1] * x.dim() - repeat_idx[dim] = n_tile - x = x.repeat(*(repeat_idx)) - order_index = torch.LongTensor( - np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) - ) - return torch.index_select(x, dim, order_index.to(x.device)) - - def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py index d1ff071..2fdf7da 100644 --- a/minigpt4/models/mini_gpt4.py +++ b/minigpt4/models/mini_gpt4.py @@ -8,6 +8,7 @@ import torch.nn as nn from minigpt4.common.registry import registry from minigpt4.models.base_model import BaseModel, disabled_train from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers import LlamaTokenizer @@ -31,6 +32,99 @@ class MiniGPT4(MiniGPTBase): "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", } + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + has_qformer=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. 
@@ -82,9 +176,6 @@ class MiniGPT4(MiniGPTBase):
         max_txt_len = cfg.get("max_txt_len", 32)
         end_sym = cfg.get("end_sym", '\n')
 
-        lora_r = cfg.get("lora_r", 0)
-        lora_alpha = cfg.get("lora_alpha", 32)
-
         model = cls(
             vit_model=vit_model,
             q_former_model=q_former_model,
@@ -103,8 +194,6 @@ class MiniGPT4(MiniGPTBase):
             end_sym=end_sym,
             low_resource=low_resource,
             device_8bit=device_8bit,
-            lora_r=lora_r,
-            lora_alpha=lora_alpha,
         )
 
         ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
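With `lora_r`/`lora_alpha` gone from `from_config`, any leftover LoRA keys in a MiniGPT-4 config are now ignored and the LLM stays fully frozen (`MiniGPTBase` defaults to `lora_r=0`). A sketch of how a LoRA-enabled variant would route the options through the refactored base class instead; the checkpoint path is a placeholder:

```python
from minigpt4.models.minigpt_base import MiniGPTBase

model = MiniGPTBase(
    vit_model="eva_clip_g",
    llama_model="path/to/llama-2-7b-chat-hf",  # placeholder local checkpoint
    lora_r=64,                                 # 0 (the default) keeps the LLM frozen
    lora_target_modules=["q_proj", "v_proj"],
    lora_alpha=16,
    lora_dropout=0.05,
)
```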
diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py
index 7c4008d..0a8c848 100644
--- a/minigpt4/models/minigpt_base.py
+++ b/minigpt4/models/minigpt_base.py
@@ -13,9 +13,7 @@ from transformers import LlamaTokenizer
 from peft import (
     LoraConfig,
     get_peft_model,
-    get_peft_model_state_dict,
     prepare_model_for_int8_training,
-    set_peft_model_state_dict,
 )
 
 
@@ -27,131 +25,41 @@ class MiniGPTBase(BaseModel):
     def __init__(
         self,
         vit_model="eva_clip_g",
-        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
         img_size=224,
         drop_path_rate=0,
         use_grad_checkpoint=False,
         vit_precision="fp16",
         freeze_vit=True,
-        has_qformer=True,
-        freeze_qformer=True,
-        num_query_token=32,
         llama_model="",
-        prompt_path="",
-        prompt_template="",
         max_txt_len=32,
         end_sym='\n',
         low_resource=False,  # use 8 bit and put vit in cpu
         device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
-        lora_r=0,
+        lora_r=0,  # lora_r == 0 means LoRA is not used
         lora_target_modules=["q_proj", "v_proj"],
         lora_alpha=16,
         lora_dropout=0.05,
     ):
         super().__init__()
 
-        self.tokenizer = self.init_tokenizer()
-        self.low_resource = low_resource
+        self.llama_model, self.llama_tokenizer = self.init_llm(
+            llama_model_path=llama_model,
+            low_resource=low_resource,
+            low_res_device=device_8bit,
+            lora_r=lora_r,
+            lora_target_modules=lora_target_modules,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+        )
 
-        print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
-            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit
         )
-        if freeze_vit:
-            for name, param in self.visual_encoder.named_parameters():
-                param.requires_grad = False
-            self.visual_encoder = self.visual_encoder.eval()
-            self.visual_encoder.train = disabled_train
-            for name, param in self.ln_vision.named_parameters():
-                param.requires_grad = False
-            self.ln_vision = self.ln_vision.eval()
-            self.ln_vision.train = disabled_train
-            logging.info("freeze vision encoder")
-        print('Loading VIT Done')
 
-        self.has_qformer = has_qformer
-        if self.has_qformer:
-            print('Loading Q-Former')
-            self.Qformer, self.query_tokens = self.init_Qformer(
-                num_query_token, self.visual_encoder.num_features
-            )
-            self.Qformer.cls = None
-            self.Qformer.bert.embeddings.word_embeddings = None
-            self.Qformer.bert.embeddings.position_embeddings = None
-            for layer in self.Qformer.bert.encoder.layer:
-                layer.output = None
-                layer.intermediate = None
-            self.load_from_pretrained(url_or_filename=q_former_model)
-
-            if freeze_qformer:
-                for name, param in self.Qformer.named_parameters():
-                    param.requires_grad = False
-                self.Qformer = self.Qformer.eval()
-                self.Qformer.train = disabled_train
-                self.query_tokens.requires_grad = False
-                logging.info("freeze Qformer")
-
-            img_f_dim = self.Qformer.config.hidden_size
-            print('Loading Q-Former Done')
-        else:
-            img_f_dim = self.visual_encoder.num_features * 4
-            print('Do not use Q-Former here.')
-
-        print('Loading LLAMA')
-        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
-        self.llama_tokenizer.pad_token = "$$"
-
-        if self.low_resource:
-            self.llama_model = LlamaForCausalLM.from_pretrained(
-                llama_model,
-                torch_dtype=torch.float16,
-                load_in_8bit=True,
-                device_map={'': device_8bit}
-            )
-        else:
-            self.llama_model = LlamaForCausalLM.from_pretrained(
-                llama_model,
-                torch_dtype=torch.float16,
-            )
-
-        if lora_r > 0:
-            self.llama_model = prepare_model_for_int8_training(self.llama_model)
-            loraconfig = LoraConfig(
-                r=lora_r,
-                lora_alpha=lora_alpha,
-                target_modules=lora_target_modules,
-                lora_dropout=lora_dropout,
-                bias="none",
-                task_type="CAUSAL_LM"
-            )
-            self.llama_model = get_peft_model(self.llama_model, loraconfig)
-
-            # if ckpt_path:
-            #     print('load the llm under lora')
-            #     ckpt = torch.load(ckpt_path)
-            #     set_peft_model_state_dict(self.llama_model,ckpt)
-            self.llama_model.print_trainable_parameters()
-
-        else:
-            for name, param in self.llama_model.named_parameters():
-                param.requires_grad = False
-        print('Loading LLAMA Done')
-
-        self.llama_proj = nn.Linear(
-            img_f_dim, self.llama_model.config.hidden_size
-        )
         self.max_txt_len = max_txt_len
         self.end_sym = end_sym
 
-        if prompt_path:
-            with open(prompt_path, 'r') as f:
-                raw_prompts = f.read().splitlines()
-            filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
-            self.prompt_list = [prompt_template.format(p) for p in filted_prompts]
-            print('Load {} training prompts'.format(len(self.prompt_list)))
-            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
-        else:
-            self.prompt_list = []
+        self.prompt_list = []
 
     def vit_to_cpu(self):
         self.ln_vision.to("cpu")
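A quick end-to-end sanity check of the refactor, assuming local checkpoints are available (`q_former_model` defaults to downloading the BLIP-2 weights): with `freeze_vit=True`, `freeze_qformer=True`, and no LoRA, only the `llama_proj` linear layer should remain trainable.

```python
from minigpt4.models.mini_gpt4 import MiniGPT4

model = MiniGPT4(
    llama_model="path/to/llama-2-7b-chat-hf",  # placeholder local checkpoint
    freeze_vit=True,
    freeze_qformer=True,
)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['llama_proj.weight', 'llama_proj.bias']
```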