"""
|
|
Copyright (c) 2022, salesforce.com, inc.
|
|
All rights reserved.
|
|
SPDX-License-Identifier: BSD-3-Clause
|
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
"""

import os
import logging
import contextlib

from omegaconf import OmegaConf
import numpy as np
import torch
import torch.nn as nn
from transformers import BertTokenizer, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)

from minigpt4.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
from minigpt4.common.utils import get_abs_path, is_url
from minigpt4.models.eva_vit import create_eva_vit_g


class BaseModel(nn.Module):
    """Base class for models."""

    def __init__(self):
        super().__init__()

    @property
    def device(self):
        return list(self.parameters())[-1].device

    def load_checkpoint(self, url_or_filename):
        """
        Load from a finetuned checkpoint.

        This should expect no mismatch between the model keys and the checkpoint keys.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        if "model" in checkpoint.keys():
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

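    # Usage sketch (hedged): load_checkpoint accepts either a URL or a local path;
    # the path below is a placeholder, not a real artifact.
    #
    #     msg = model.load_checkpoint("checkpoints/finetuned.pth")  # hypothetical path
    #     assert not msg.missing_keys  # finetuned checkpoints should match exactly
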
    @classmethod
    def from_pretrained(cls, model_type):
        """
        Build a pretrained model from the default configuration file, specified by model_type.

        Args:
            - model_type (str): model type, specifying architecture and checkpoints.

        Returns:
            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
        """
        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
        model = cls.from_config(model_cfg)

        return model

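    # Usage sketch (hedged): concrete subclasses supply PRETRAINED_MODEL_CONFIG_DICT
    # and from_config(); the model-type key below is illustrative.
    #
    #     model = MiniGPT4.from_pretrained("pretrain_vicuna0")  # hypothetical key
    #     model.eval()
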
    @classmethod
    def default_config_path(cls, model_type):
        assert (
            model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
        ), "Unknown model type {}".format(model_type)
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

    def load_checkpoint_from_config(self, cfg, **kwargs):
        """
        Load checkpoint as specified in the config file.

        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
        When loading the pretrained model, each task-specific architecture may define its
        own load_from_pretrained() method.
        """
        load_finetuned = cfg.get("load_finetuned", True)
        if load_finetuned:
            finetune_path = cfg.get("finetuned", None)
            assert (
                finetune_path is not None
            ), "Found load_finetuned is True, but finetune_path is None."
            self.load_checkpoint(url_or_filename=finetune_path)
        else:
            # load pre-trained weights
            pretrain_path = cfg.get("pretrained", None)
            assert (
                pretrain_path is not None
            ), "Found load_finetuned is False, but pretrain_path is None."
            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)

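    # Config sketch (hedged): the YAML keys read above, with placeholder values.
    #
    #     model:
    #       load_finetuned: False
    #       pretrained: "https://example.org/pretrained.pth"  # placeholder URL
    #       finetuned: "path/to/finetuned.pth"                # placeholder path
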
    def before_evaluation(self, **kwargs):
        pass

    def show_n_params(self, return_str=True):
        tot = 0
        for p in self.parameters():
            w = 1
            for x in p.shape:
                w *= x
            tot += w
        if return_str:
            if tot >= 1e6:
                return "{:.1f}M".format(tot / 1e6)
            else:
                return "{:.1f}K".format(tot / 1e3)
        else:
            return tot

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

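    # Usage sketch (hedged): both branches return a context manager, so callers can
    # wrap mixed-precision regions uniformly; the encoder call is illustrative.
    #
    #     with self.maybe_autocast():
    #         image_embeds = self.visual_encoder(image)  # fp16 on GPU, fp32 on CPU
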
    @classmethod
    def init_vision_encoder(
        cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze
    ):
        logging.info('Loading VIT')

        assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of MiniGPT-4"
        if not freeze:
            precision = "fp32"  # fp16 is not suitable for training

        visual_encoder = create_eva_vit_g(
            img_size, drop_path_rate, use_grad_checkpoint, precision
        )

        ln_vision = LayerNorm(visual_encoder.num_features)

        if freeze:
            for name, param in visual_encoder.named_parameters():
                param.requires_grad = False
            visual_encoder = visual_encoder.eval()
            visual_encoder.train = disabled_train
            for name, param in ln_vision.named_parameters():
                param.requires_grad = False
            ln_vision = ln_vision.eval()
            ln_vision.train = disabled_train
            logging.info("freeze vision encoder")

        logging.info('Loading VIT Done')
        return visual_encoder, ln_vision

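    # Usage sketch (hedged): typical eva_clip_g settings; the values are illustrative
    # defaults, not fixed by this file.
    #
    #     visual_encoder, ln_vision = cls.init_vision_encoder(
    #         "eva_clip_g", img_size=224, drop_path_rate=0,
    #         use_grad_checkpoint=False, precision="fp16", freeze=True,
    #     )
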
    @classmethod
    def init_llm(cls, llama_model_path, low_resource=False, low_res_device=0, lora_r=0,
                 lora_target_modules=["q_proj", "v_proj"], **lora_kargs):
        logging.info('Loading LLAMA')
        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
        llama_tokenizer.pad_token = "$$"

        if low_resource:
            llama_model = LlamaForCausalLM.from_pretrained(
                llama_model_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': low_res_device}
            )
        else:
            llama_model = LlamaForCausalLM.from_pretrained(
                llama_model_path,
                torch_dtype=torch.float16,
            )

        if lora_r > 0:
            llama_model = prepare_model_for_int8_training(llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=lora_target_modules,
                **lora_kargs
            )
            llama_model = get_peft_model(llama_model, loraconfig)

            llama_model.print_trainable_parameters()

        else:
            for name, param in llama_model.named_parameters():
                param.requires_grad = False
        logging.info('Loading LLAMA Done')
        return llama_model, llama_tokenizer

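    # Usage sketch (hedged): the checkpoint path and LoRA settings are illustrative.
    # lora_alpha reaches LoraConfig via **lora_kargs.
    #
    #     llama_model, llama_tokenizer = cls.init_llm(
    #         "path/to/vicuna-7b",  # hypothetical local checkpoint dir
    #         low_resource=True, lora_r=64, lora_alpha=16,
    #     )
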
    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)
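
# Behavior sketch (hedged): computing LayerNorm statistics in fp32 avoids fp16
# overflow/precision loss while preserving the caller's dtype; the feature dim
# below is illustrative.
#
#     ln = LayerNorm(1408)
#     y = ln(torch.randn(2, 257, 1408).half())
#     assert y.dtype == torch.float16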