""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import contextlib import logging import os import numpy as np import torch import torch.nn as nn from transformers import BertTokenizer from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized from minigpt4.common.utils import get_abs_path, is_url from minigpt4.models.Qformer import BertConfig, BertLMHeadModel from minigpt4.models.eva_vit import create_eva_vit_g from omegaconf import OmegaConf class BaseModel(nn.Module): """Base class for models.""" def __init__(self): super().__init__() @property def device(self): return list(self.parameters())[-1].device def load_checkpoint(self, url_or_filename): """ Load from a finetuned checkpoint. This should expect no mismatch in the model keys and the checkpoint keys. """ if is_url(url_or_filename): cached_file = download_cached_file( url_or_filename, check_hash=False, progress=True ) checkpoint = torch.load(cached_file, map_location="cpu") elif os.path.isfile(url_or_filename): checkpoint = torch.load(url_or_filename, map_location="cpu") else: raise RuntimeError("checkpoint url or path is invalid") if "model" in checkpoint.keys(): state_dict = checkpoint["model"] else: state_dict = checkpoint msg = self.load_state_dict(state_dict, strict=False) logging.info("Missing keys {}".format(msg.missing_keys)) logging.info("load checkpoint from %s" % url_or_filename) return msg @classmethod def from_pretrained(cls, model_type): """ Build a pretrained model from default configuration file, specified by model_type. Args: - model_type (str): model type, specifying architecture and checkpoints. Returns: - model (nn.Module): pretrained or finetuned model, depending on the configuration. """ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model model = cls.from_config(model_cfg) return model @classmethod def default_config_path(cls, model_type): assert ( model_type in cls.PRETRAINED_MODEL_CONFIG_DICT ), "Unknown model type {}".format(model_type) return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) def load_checkpoint_from_config(self, cfg, **kwargs): """ Load checkpoint as specified in the config file. If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. When loading the pretrained model, each task-specific architecture may define their own load_from_pretrained() method. """ load_finetuned = cfg.get("load_finetuned", True) if load_finetuned: finetune_path = cfg.get("finetuned", None) assert ( finetune_path is not None ), "Found load_finetuned is True, but finetune_path is None." self.load_checkpoint(url_or_filename=finetune_path) else: # load pre-trained weights pretrain_path = cfg.get("pretrained", None) assert "Found load_finetuned is False, but pretrain_path is None." 
    def before_evaluation(self, **kwargs):
        pass

    def show_n_params(self, return_str=True):
        tot = 0
        for p in self.parameters():
            w = 1
            for x in p.shape:
                w *= x
            tot += w
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            else:
                return "{:.1f}K".format(tot / 1e3)
        else:
            return tot

    @classmethod
    def init_tokenizer(cls):
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
    ):
        assert model_name == "eva_clip_g", (
            "vit model must be eva_clip_g for current version of MiniGPT-4"
        )
        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)
        return visual_encoder, ln_vision

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg
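# Illustrative subclass usage (a sketch, not part of this module): concrete models such
# as MiniGPT-4 typically build their vision tower with the helpers above and wrap the
# image forward pass in maybe_autocast so fp16 is only active on GPU, e.g.
#
#   self.visual_encoder, self.ln_vision = self.init_vision_encoder(
#       "eva_clip_g", img_size, drop_path_rate, use_grad_checkpoint, precision
#   )
#   with self.maybe_autocast():
#       image_embeds = self.ln_vision(self.visual_encoder(image))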
""" def __init__(self): super().__init__() def forward_features(self, samples, **kwargs): raise NotImplementedError @property def device(self): return list(self.parameters())[0].device class SharedQueueMixin: @torch.no_grad() def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): # gather keys before updating queue image_feats = concat_all_gather(image_feat) text_feats = concat_all_gather(text_feat) batch_size = image_feats.shape[0] ptr = int(self.queue_ptr) assert self.queue_size % batch_size == 0 # for simplicity # replace the keys at ptr (dequeue and enqueue) self.image_queue[:, ptr : ptr + batch_size] = image_feats.T self.text_queue[:, ptr : ptr + batch_size] = text_feats.T if idxs is not None: idxs = concat_all_gather(idxs) self.idx_queue[:, ptr : ptr + batch_size] = idxs.T ptr = (ptr + batch_size) % self.queue_size # move pointer self.queue_ptr[0] = ptr class MomentumDistilationMixin: @torch.no_grad() def copy_params(self): for model_pair in self.model_pairs: for param, param_m in zip( model_pair[0].parameters(), model_pair[1].parameters() ): param_m.data.copy_(param.data) # initialize param_m.requires_grad = False # not update by gradient @torch.no_grad() def _momentum_update(self): for model_pair in self.model_pairs: for param, param_m in zip( model_pair[0].parameters(), model_pair[1].parameters() ): param_m.data = param_m.data * self.momentum + param.data * ( 1.0 - self.momentum ) class GatherLayer(torch.autograd.Function): """ Gather tensors from all workers with support for backward propagation: This implementation does not cut the gradients as torch.distributed.all_gather does. """ @staticmethod def forward(ctx, x): output = [ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) ] torch.distributed.all_gather(output, x) return tuple(output) @staticmethod def backward(ctx, *grads): all_gradients = torch.stack(grads) torch.distributed.all_reduce(all_gradients) return all_gradients[torch.distributed.get_rank()] def all_gather_with_grad(tensors): """ Performs all_gather operation on the provided tensors. Graph remains connected for backward grad computation. """ # Queue the gathered tensors world_size = torch.distributed.get_world_size() # There is no need for reduction in the single-proc case if world_size == 1: return tensors # tensor_all = GatherLayer.apply(tensors) tensor_all = GatherLayer.apply(tensors) return torch.cat(tensor_all, dim=0) @torch.no_grad() def concat_all_gather(tensor): """ Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. 
""" # if use distributed training if not is_dist_avail_and_initialized(): return tensor tensors_gather = [ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) ] torch.distributed.all_gather(tensors_gather, tensor, async_op=False) output = torch.cat(tensors_gather, dim=0) return output def tile(x, dim, n_tile): init_dim = x.size(dim) repeat_idx = [1] * x.dim() repeat_idx[dim] = n_tile x = x.repeat(*(repeat_idx)) order_index = torch.LongTensor( np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) ) return torch.index_select(x, dim, order_index.to(x.device)) def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" def forward(self, x: torch.Tensor): orig_type = x.dtype ret = super().forward(x.type(torch.float32)) return ret.type(orig_type)