diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml
index b653eb7..15af837 100644
--- a/eval_configs/minigpt4_eval.yaml
+++ b/eval_configs/minigpt4_eval.yaml
@@ -5,7 +5,8 @@ model:
   end_sym: "###"
   low_resource: True
   prompt_template: '###Human: {} ###Assistant: '
-  ckpt: '/path/to/checkpoint/'
+  ckpt: '/home/zhud/weights/minigpt4/prerained_minigpt4_7b.pth'
+  llama_model: "/home/zhud/weights/vicuna-7b"
 
 
 datasets:
diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/eval_configs/minigpt4_llama2_eval.yaml
index eea99d3..62709e1 100644
--- a/eval_configs/minigpt4_llama2_eval.yaml
+++ b/eval_configs/minigpt4_llama2_eval.yaml
@@ -5,7 +5,8 @@ model:
   end_sym: "</s>"
   low_resource: True
   prompt_template: '[INST] {} [/INST] '
-  ckpt: '/path/to/checkpoint/'
+  ckpt: '/home/zhud/weights/minigpt4/pretrained_minigpt4_llama2_7b.pth'
+  llama_model: "/ibex/project/c2133/llama_v2/llama-2-7b-chat-pytorch_update"
 
 
 datasets:
diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py
index 54acd24..bc4c084 100644
--- a/minigpt4/models/__init__.py
+++ b/minigpt4/models/__init__.py
@@ -11,7 +11,6 @@ from omegaconf import OmegaConf
 from minigpt4.common.registry import registry
 
 from minigpt4.models.base_model import BaseModel
-from minigpt4.models.blip2 import Blip2Base
 from minigpt4.models.mini_gpt4 import MiniGPT4
 from minigpt4.processors.base_processor import BaseProcessor
 
@@ -19,7 +18,6 @@ from minigpt4.processors.base_processor import BaseProcessor
 __all__ = [
     "load_model",
     "BaseModel",
-    "Blip2Base",
     "MiniGPT4",
 ]
diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py
index 2a13393..fa21962 100644
--- a/minigpt4/models/base_model.py
+++ b/minigpt4/models/base_model.py
@@ -5,14 +5,18 @@
  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 """
 
+import contextlib
 import logging
 import os
 
 import numpy as np
 import torch
 import torch.nn as nn
+from transformers import BertTokenizer
 
 from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
 from minigpt4.common.utils import get_abs_path, is_url
+from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
+from minigpt4.models.eva_vit import create_eva_vit_g
 from omegaconf import OmegaConf
 
@@ -117,6 +121,70 @@ class BaseModel(nn.Module):
         else:
             return tot
 
+    @classmethod
+    def init_tokenizer(cls):
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+        return tokenizer
+
+    def maybe_autocast(self, dtype=torch.float16):
+        # if on cpu, don't use autocast
+        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+        enable_autocast = self.device != torch.device("cpu")
+
+        if enable_autocast:
+            return torch.cuda.amp.autocast(dtype=dtype)
+        else:
+            return contextlib.nullcontext()
+
+    @classmethod
+    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
+        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+        encoder_config.encoder_width = vision_width
+        # insert cross-attention layer every other block
+        encoder_config.add_cross_attention = True
+        encoder_config.cross_attention_freq = cross_attention_freq
+        encoder_config.query_length = num_query_token
+        Qformer = BertLMHeadModel(config=encoder_config)
+        query_tokens = nn.Parameter(
+            torch.zeros(1, num_query_token, encoder_config.hidden_size)
+        )
+        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+        return Qformer, query_tokens
+
+    @classmethod
+    def init_vision_encoder(
+        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+    ):
+        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
+        visual_encoder = create_eva_vit_g(
+            img_size, drop_path_rate, use_grad_checkpoint, precision
+        )
+
+        ln_vision = LayerNorm(visual_encoder.num_features)
+        return visual_encoder, ln_vision
+
+    def load_from_pretrained(self, url_or_filename):
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location="cpu")
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location="cpu")
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        # logging.info("Missing keys {}".format(msg.missing_keys))
+        logging.info("load checkpoint from %s" % url_or_filename)
+
+        return msg
+
+
 
 class BaseEncoder(nn.Module):
     """
@@ -245,3 +313,23 @@ def tile(x, dim, n_tile):
         np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
     )
     return torch.index_select(x, dim, order_index.to(x.device))
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+
+
+
diff --git a/minigpt4/models/blip2.py b/minigpt4/models/blip2.py
deleted file mode 100644
index ee4a9dc..0000000
--- a/minigpt4/models/blip2.py
+++ /dev/null
@@ -1,221 +0,0 @@
-"""
- Copyright (c) 2023, salesforce.com, inc.
- All rights reserved.
- SPDX-License-Identifier: BSD-3-Clause
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-import contextlib
-import logging
-import os
-import time
-import datetime
-
-import torch
-import torch.nn as nn
-import torch.distributed as dist
-import torch.nn.functional as F
-
-import minigpt4.common.dist_utils as dist_utils
-from minigpt4.common.dist_utils import download_cached_file
-from minigpt4.common.utils import is_url
-from minigpt4.common.logger import MetricLogger
-from minigpt4.models.base_model import BaseModel
-from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
-from minigpt4.models.eva_vit import create_eva_vit_g
-from transformers import BertTokenizer
-
-
-class Blip2Base(BaseModel):
-    @classmethod
-    def init_tokenizer(cls):
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
-        return tokenizer
-
-    def maybe_autocast(self, dtype=torch.float16):
-        # if on cpu, don't use autocast
-        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
-        enable_autocast = self.device != torch.device("cpu")
-
-        if enable_autocast:
-            return torch.cuda.amp.autocast(dtype=dtype)
-        else:
-            return contextlib.nullcontext()
-
-    @classmethod
-    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
-        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
-        encoder_config.encoder_width = vision_width
-        # insert cross-attention layer every other block
-        encoder_config.add_cross_attention = True
-        encoder_config.cross_attention_freq = cross_attention_freq
-        encoder_config.query_length = num_query_token
-        Qformer = BertLMHeadModel(config=encoder_config)
-        query_tokens = nn.Parameter(
-            torch.zeros(1, num_query_token, encoder_config.hidden_size)
-        )
-        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
-        return Qformer, query_tokens
-
-    @classmethod
-    def init_vision_encoder(
-        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
-    ):
-        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
-        visual_encoder = create_eva_vit_g(
-            img_size, drop_path_rate, use_grad_checkpoint, precision
-        )
-
-        ln_vision = LayerNorm(visual_encoder.num_features)
-        return visual_encoder, ln_vision
-
-    def load_from_pretrained(self, url_or_filename):
-        if is_url(url_or_filename):
-            cached_file = download_cached_file(
-                url_or_filename, check_hash=False, progress=True
-            )
-            checkpoint = torch.load(cached_file, map_location="cpu")
-        elif os.path.isfile(url_or_filename):
-            checkpoint = torch.load(url_or_filename, map_location="cpu")
-        else:
-            raise RuntimeError("checkpoint url or path is invalid")
-
-        state_dict = checkpoint["model"]
-
-        msg = self.load_state_dict(state_dict, strict=False)
-
-        # logging.info("Missing keys {}".format(msg.missing_keys))
-        logging.info("load checkpoint from %s" % url_or_filename)
-
-        return msg
-
-
-def disabled_train(self, mode=True):
-    """Overwrite model.train with this function to make sure train/eval mode
-    does not change anymore."""
-    return self
-
-
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
-
-    def forward(self, x: torch.Tensor):
-        orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
-        return ret.type(orig_type)
-
-
-def compute_sim_matrix(model, data_loader, **kwargs):
-    k_test = kwargs.pop("k_test")
-
-    metric_logger = MetricLogger(delimiter="  ")
header = "Evaluation:" - - logging.info("Computing features for evaluation...") - start_time = time.time() - - texts = data_loader.dataset.text - num_text = len(texts) - text_bs = 256 - text_ids = [] - text_embeds = [] - text_atts = [] - for i in range(0, num_text, text_bs): - text = texts[i : min(num_text, i + text_bs)] - text_input = model.tokenizer( - text, - padding="max_length", - truncation=True, - max_length=35, - return_tensors="pt", - ).to(model.device) - text_feat = model.forward_text(text_input) - text_embed = F.normalize(model.text_proj(text_feat)) - text_embeds.append(text_embed) - text_ids.append(text_input.input_ids) - text_atts.append(text_input.attention_mask) - - text_embeds = torch.cat(text_embeds, dim=0) - text_ids = torch.cat(text_ids, dim=0) - text_atts = torch.cat(text_atts, dim=0) - - vit_feats = [] - image_embeds = [] - for samples in data_loader: - image = samples["image"] - - image = image.to(model.device) - image_feat, vit_feat = model.forward_image(image) - image_embed = model.vision_proj(image_feat) - image_embed = F.normalize(image_embed, dim=-1) - - vit_feats.append(vit_feat.cpu()) - image_embeds.append(image_embed) - - vit_feats = torch.cat(vit_feats, dim=0) - image_embeds = torch.cat(image_embeds, dim=0) - - sims_matrix = [] - for image_embed in image_embeds: - sim_q2t = image_embed @ text_embeds.t() - sim_i2t, _ = sim_q2t.max(0) - sims_matrix.append(sim_i2t) - sims_matrix = torch.stack(sims_matrix, dim=0) - - score_matrix_i2t = torch.full( - (len(data_loader.dataset.image), len(texts)), -100.0 - ).to(model.device) - - num_tasks = dist_utils.get_world_size() - rank = dist_utils.get_rank() - step = sims_matrix.size(0) // num_tasks + 1 - start = rank * step - end = min(sims_matrix.size(0), start + step) - - for i, sims in enumerate( - metric_logger.log_every(sims_matrix[start:end], 50, header) - ): - topk_sim, topk_idx = sims.topk(k=k_test, dim=0) - image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) - score = model.compute_itm( - image_inputs=image_inputs, - text_ids=text_ids[topk_idx], - text_atts=text_atts[topk_idx], - ).float() - score_matrix_i2t[start + i, topk_idx] = score + topk_sim - - sims_matrix = sims_matrix.t() - score_matrix_t2i = torch.full( - (len(texts), len(data_loader.dataset.image)), -100.0 - ).to(model.device) - - step = sims_matrix.size(0) // num_tasks + 1 - start = rank * step - end = min(sims_matrix.size(0), start + step) - - for i, sims in enumerate( - metric_logger.log_every(sims_matrix[start:end], 50, header) - ): - topk_sim, topk_idx = sims.topk(k=k_test, dim=0) - image_inputs = vit_feats[topk_idx.cpu()].to(model.device) - score = model.compute_itm( - image_inputs=image_inputs, - text_ids=text_ids[start + i].repeat(k_test, 1), - text_atts=text_atts[start + i].repeat(k_test, 1), - ).float() - score_matrix_t2i[start + i, topk_idx] = score + topk_sim - - if dist_utils.is_dist_avail_and_initialized(): - dist.barrier() - torch.distributed.all_reduce( - score_matrix_i2t, op=torch.distributed.ReduceOp.SUM - ) - torch.distributed.all_reduce( - score_matrix_t2i, op=torch.distributed.ReduceOp.SUM - ) - - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - logging.info("Evaluation time {}".format(total_time_str)) - - return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/minigpt4/models/blip2_outputs.py b/minigpt4/models/blip2_outputs.py deleted file mode 100644 index e8722b1..0000000 --- a/minigpt4/models/blip2_outputs.py +++ /dev/null 
@@ -1,110 +0,0 @@
-"""
- Copyright (c) 2022, salesforce.com, inc.
- All rights reserved.
- SPDX-License-Identifier: BSD-3-Clause
- For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-"""
-
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-from transformers.modeling_outputs import (
-    ModelOutput,
-    BaseModelOutputWithPoolingAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-)
-
-
-@dataclass
-class BlipSimilarity(ModelOutput):
-    sim_i2t: torch.FloatTensor = None
-    sim_t2i: torch.FloatTensor = None
-
-    sim_i2t_m: Optional[torch.FloatTensor] = None
-    sim_t2i_m: Optional[torch.FloatTensor] = None
-
-    sim_i2t_targets: Optional[torch.FloatTensor] = None
-    sim_t2i_targets: Optional[torch.FloatTensor] = None
-
-
-@dataclass
-class BlipIntermediateOutput(ModelOutput):
-    """
-    Data class for intermediate outputs of BLIP models.
-
-    image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
-    text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
-
-    image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
-    text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
-
-    encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
-    encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
-
-    decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
-    decoder_labels (torch.LongTensor): labels for the captioning loss.
-
-    itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
-    itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
-
-    """
-
-    # uni-modal features
-    image_embeds: torch.FloatTensor = None
-    text_embeds: Optional[torch.FloatTensor] = None
-
-    image_embeds_m: Optional[torch.FloatTensor] = None
-    text_embeds_m: Optional[torch.FloatTensor] = None
-
-    # intermediate outputs of multimodal encoder
-    encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
-    encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
-
-    itm_logits: Optional[torch.FloatTensor] = None
-    itm_labels: Optional[torch.LongTensor] = None
-
-    # intermediate outputs of multimodal decoder
-    decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
-    decoder_labels: Optional[torch.LongTensor] = None
-
-
-@dataclass
-class BlipOutput(ModelOutput):
-    # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
-    sims: Optional[BlipSimilarity] = None
-
-    intermediate_output: BlipIntermediateOutput = None
-
-    loss: Optional[torch.FloatTensor] = None
-
-    loss_itc: Optional[torch.FloatTensor] = None
-
-    loss_itm: Optional[torch.FloatTensor] = None
-
-    loss_lm: Optional[torch.FloatTensor] = None
-
-
-@dataclass
-class BlipOutputFeatures(ModelOutput):
-    """
-    Data class of features from BlipFeatureExtractor.
-
-    Args:
-        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
-        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
-        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
-        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
-
-        The first embedding or feature is for the [CLS] token.
-
-        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
-    """
-
-    image_embeds: Optional[torch.FloatTensor] = None
-    image_embeds_proj: Optional[torch.FloatTensor] = None
-
-    text_embeds: Optional[torch.FloatTensor] = None
-    text_embeds_proj: Optional[torch.FloatTensor] = None
-
-    multimodal_embeds: Optional[torch.FloatTensor] = None
diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index faed3d5..2575f32 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -6,7 +6,7 @@ from torch.cuda.amp import autocast as autocast
 import torch.nn as nn
 
 from minigpt4.common.registry import registry
-from minigpt4.models.blip2 import Blip2Base, disabled_train
+from minigpt4.models.base_model import BaseModel, disabled_train
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from transformers import LlamaTokenizer
 
@@ -20,9 +20,9 @@ from peft import (
 
 
 @registry.register_model("mini_gpt4")
-class MiniGPT4(Blip2Base):
+class MiniGPT4(BaseModel):
     """
-    BLIP2 GPT-LLAMA model.
+    MiniGPT-4 model.
     """
 
     PRETRAINED_MODEL_CONFIG_DICT = {
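
How the consolidated helpers fit together: after this refactor, MiniGPT4 inherits init_tokenizer, maybe_autocast, init_Qformer, init_vision_encoder, and load_from_pretrained directly from BaseModel instead of going through Blip2Base. The sketch below is an editor's illustration, not code from this patch: the class name ToyVLM, the 224-pixel image size, the 32 query tokens, and the checkpoint path are assumptions chosen to mirror typical MiniGPT-4 defaults.

    import torch
    from minigpt4.models.base_model import BaseModel, disabled_train

    class ToyVLM(BaseModel):
        """Hypothetical model exercising the BaseModel helper API."""

        def __init__(self):
            super().__init__()
            # Only "eva_clip_g" is accepted; returns the ViT plus an fp32 LayerNorm.
            self.visual_encoder, self.ln_vision = self.init_vision_encoder(
                "eva_clip_g", img_size=224, drop_path_rate=0.0,
                use_grad_checkpoint=False, precision="fp16",
            )
            # Q-Former with 32 learned query tokens over the ViT feature width.
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_query_token=32, vision_width=self.visual_encoder.num_features,
            )
            # Freeze the vision tower and pin it in eval mode (the repo's pattern).
            for p in self.visual_encoder.parameters():
                p.requires_grad = False
            self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train

        def encode_img(self, image: torch.Tensor) -> torch.Tensor:
            # fp16 autocast on GPU; plain no-op context manager on CPU.
            with self.maybe_autocast():
                return self.ln_vision(self.visual_encoder(image))

    model = ToyVLM()
    # Non-strict loading: keys absent from the checkpoint are reported in msg,
    # not fatal. The path is a placeholder, not a real artifact.
    msg = model.load_from_pretrained("/path/to/stage1_checkpoint.pth")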
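
Why the fp16-safe LayerNorm kept at the bottom of base_model.py matters: under maybe_autocast the ViT activations arrive as float16, and computing normalization statistics at half precision can overflow, so the subclass upcasts to float32 for the reduction and casts the result back to the input dtype. A small illustration (the 2 x 257 x 1408 shape mirrors EVA ViT-g's batch, token count, and feature width; the numbers are examples, not values from this patch):

    import torch
    from minigpt4.models.base_model import LayerNorm

    x = torch.randn(2, 257, 1408, dtype=torch.float16)  # half-precision features
    ln = LayerNorm(1408)   # parameters stay in float32
    y = ln(x)              # forward upcasts to fp32, normalizes, casts back
    assert y.dtype == torch.float16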