MiniGPT-4/minigpt4/models/blip2.py
2024-03-28 14:48:42 +08:00

484 lines
19 KiB
Python

"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import os
import time
import datetime
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
import minigpt4.common.dist_utils as dist_utils
from minigpt4.common.dist_utils import download_cached_file
from minigpt4.common.utils import is_url
from minigpt4.common.logger import MetricLogger
from minigpt4.models.base_model import BaseModel
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel
from minigpt4.models.QformerMoE import BertMoELMHeadModel
from minigpt4.models.QformerMoELN import BertMoELMHeadModelLNIn
from minigpt4.models.QformerRouteMoE import BertMoERouteLMHeadModel
from minigpt4.models.QformerRouteMoELN import BertMoERouteLMHeadModelLNIn
from minigpt4.models.QformerRouteMoELNUni import BertMoERouteLMHeadModelLNInUniversal
from minigpt4.models.QformerRouteMoEUni import BertMoERouteLMHeadModelUniversal
from minigpt4.models.eva_vit import create_eva_vit_g
from transformers import BertTokenizer
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_int8_training,
)
class Blip2Base(BaseModel):
    """Shared building blocks for BLIP-2 / MiniGPT-4 style models.

    Provides constructors for the BERT tokenizer, the plain and MoE Q-Former
    variants, the LLaMA language model and the vision encoder, plus checkpoint
    loading, optimizer parameter grouping and optional spaCy lemmatization.
    """

    # Local ``bert-base-uncased`` directory used for the tokenizer and for every
    # Q-Former config/weights below. Override on a subclass to relocate it.
    BERT_BASE_PATH = "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased"

    @classmethod
    def init_tokenizer(cls, truncation_side="right"):
        """Load the BERT tokenizer and register ``[DEC]`` as its BOS token.

        Args:
            truncation_side: forwarded to ``BertTokenizer.from_pretrained``.

        Returns:
            The configured ``BertTokenizer``.
        """
        tokenizer = BertTokenizer.from_pretrained(
            cls.BERT_BASE_PATH, truncation_side=truncation_side
        )
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context, or a no-op context on CPU.

        Args:
            dtype: autocast dtype used when the model is not on the CPU.
        """
        # autocast is only enabled off-CPU; on CPU a null context is returned
        if self.device != torch.device("cpu"):
            return torch.cuda.amp.autocast(dtype=dtype)
        return contextlib.nullcontext()

    # ------------------------------------------------------------------
    # Q-Former construction helpers
    # ------------------------------------------------------------------
    @classmethod
    def _build_qformer_config(
        cls, num_query_token, vision_width, cross_attention_freq, **extra
    ):
        """Return a ``BertConfig`` for a cross-attending Q-Former.

        Any ``extra`` keyword arguments are set verbatim as config attributes
        (used for the MoE-specific options).
        """
        config = BertConfig.from_pretrained(cls.BERT_BASE_PATH)
        config.encoder_width = vision_width
        # insert a cross-attention layer every ``cross_attention_freq`` blocks
        config.add_cross_attention = True
        config.cross_attention_freq = cross_attention_freq
        config.query_length = num_query_token
        for key, value in extra.items():
            setattr(config, key, value)
        return config

    @staticmethod
    def _init_query_tokens(num_query_token, config):
        """Create learnable query embeddings drawn from N(0, initializer_range)."""
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=config.initializer_range)
        return query_tokens

    @classmethod
    def _load_ln_variant(cls, ln_position, out_model_cls, in_model_cls, config):
        """Load ``out_model_cls`` for ln_position "out", ``in_model_cls`` for "in".

        Raises:
            ValueError: if ``ln_position`` is neither "out" nor "in" (previously
                this surfaced as an ``UnboundLocalError``).
        """
        variants = {"out": out_model_cls, "in": in_model_cls}
        if ln_position not in variants:
            raise ValueError(
                f"ln_position must be 'out' or 'in', got {ln_position!r}"
            )
        return variants[ln_position].from_pretrained(cls.BERT_BASE_PATH, config=config)

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
        """Build the standard (dense) Q-Former and its query tokens.

        Returns:
            ``(Qformer, query_tokens)`` where ``query_tokens`` has shape
            ``(1, num_query_token, hidden_size)``.
        """
        encoder_config = cls._build_qformer_config(
            num_query_token, vision_width, cross_attention_freq
        )
        Qformer = BertLMHeadModel.from_pretrained(
            cls.BERT_BASE_PATH, config=encoder_config
        )
        query_tokens = cls._init_query_tokens(num_query_token, encoder_config)
        return Qformer, query_tokens

    @classmethod
    def init_RouteMoEQformerUni(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"):
        """Build the "universal" route-MoE Q-Former and its query tokens.

        ``ln_position`` "out" selects ``BertMoERouteLMHeadModelUniversal``;
        "in" selects ``BertMoERouteLMHeadModelLNInUniversal``.

        Raises:
            ValueError: for any other ``ln_position``.
        """
        moe_encoder_config = cls._build_qformer_config(
            num_query_token,
            vision_width,
            cross_attention_freq,
            moebert_expert_num=moebert_expert_num,
            moebert_num_beams=moebert_num_beams,
            route_method=route_method,
            moe_weight_type=moe_weight_type,
        )
        RouteMoEQformer = cls._load_ln_variant(
            ln_position,
            BertMoERouteLMHeadModelUniversal,
            BertMoERouteLMHeadModelLNInUniversal,
            moe_encoder_config,
        )
        query_tokens = cls._init_query_tokens(num_query_token, moe_encoder_config)
        return RouteMoEQformer, query_tokens

    @classmethod
    def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"):
        """Build the route-MoE Q-Former and its query tokens.

        ``ln_position`` "out" selects ``BertMoERouteLMHeadModel``; "in"
        selects ``BertMoERouteLMHeadModelLNIn``.

        Raises:
            ValueError: for any other ``ln_position``.
        """
        moe_encoder_config = cls._build_qformer_config(
            num_query_token,
            vision_width,
            cross_attention_freq,
            moebert_expert_num=moebert_expert_num,
            moebert_num_beams=moebert_num_beams,
            route_method=route_method,
            moe_weight_type=moe_weight_type,
        )
        RouteMoEQformer = cls._load_ln_variant(
            ln_position,
            BertMoERouteLMHeadModel,
            BertMoERouteLMHeadModelLNIn,
            moe_encoder_config,
        )
        query_tokens = cls._init_query_tokens(num_query_token, moe_encoder_config)
        return RouteMoEQformer, query_tokens

    @classmethod
    def init_QformerMoE(cls, num_query_token, vision_width, moebert_expert_num, moebert_route_method, moebert_load_balance, moe_topk=1, use_balance_loss=True, moe_weight_type='l2_norm', cross_attention_freq=2, ln_position="out"):
        """Build the (non-routed) MoE Q-Former and its query tokens.

        ``ln_position`` "out" selects ``BertMoELMHeadModel``; "in" selects
        ``BertMoELMHeadModelLNIn``.

        Raises:
            ValueError: for any other ``ln_position``.
        """
        moe_encoder_config = cls._build_qformer_config(
            num_query_token,
            vision_width,
            cross_attention_freq,
            moebert_expert_num=moebert_expert_num,
            moebert_route_method=moebert_route_method,
            moebert_load_balance=moebert_load_balance,
            moe_topk=moe_topk,
            use_balance_loss=use_balance_loss,
            moe_weight_type=moe_weight_type,
        )
        MoEQformer = cls._load_ln_variant(
            ln_position, BertMoELMHeadModel, BertMoELMHeadModelLNIn, moe_encoder_config
        )
        query_tokens = cls._init_query_tokens(num_query_token, moe_encoder_config)
        return MoEQformer, query_tokens

    # ------------------------------------------------------------------
    # LLM / vision encoder
    # ------------------------------------------------------------------
    def init_llm(cls, llama_model_path, freeze_llm=True, lora_r=0,
                 lora_target_modules=None, **lora_kargs):
        """Load the LLaMA tokenizer and fp16 model, optionally wrapping it with LoRA.

        Note: despite the ``cls`` name this is a plain method (no
        ``@classmethod``); ``cls`` is not used in the body.

        Args:
            llama_model_path: path/identifier passed to ``from_pretrained``.
            freeze_llm: when True (or when ``lora_r`` is 0) every LLM parameter
                gets ``requires_grad = False``.
            lora_r: LoRA rank; LoRA is applied only if > 0 and not frozen.
            lora_target_modules: module names for LoRA; defaults to
                ``["q_proj", "v_proj"]``.
            **lora_kargs: extra keyword arguments forwarded to ``LoraConfig``.

        Returns:
            ``(llama_model, llama_tokenizer)``.
        """
        # a fresh list per call: avoids a shared mutable default argument
        if lora_target_modules is None:
            lora_target_modules = ["q_proj", "v_proj"]
        logging.info('Loading LLAMA')
        from transformers import LlamaTokenizer
        from minigpt4.models.modeling_llama import LlamaForCausalLM
        llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False)
        # reuse the unk token for padding instead of adding a new token
        llama_tokenizer.pad_token = llama_tokenizer.unk_token
        llama_model = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            torch_dtype=torch.float16,
        )
        llama_model.resize_token_embeddings(len(llama_tokenizer))
        if freeze_llm == False and lora_r > 0:
            # NOTE(review): newer peft releases rename this helper to
            # prepare_model_for_kbit_training — confirm against the pinned version.
            llama_model = prepare_model_for_int8_training(llama_model)
            loraconfig = LoraConfig(
                r=lora_r,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=lora_target_modules,
                **lora_kargs
            )
            llama_model = get_peft_model(llama_model, loraconfig)
            llama_model.print_trainable_parameters()
        else:
            for param in llama_model.parameters():
                param.requires_grad = False
        logging.info('Loading LLAMA Done')
        return llama_model, llama_tokenizer

    def init_vision_encoder(
        self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze_vit=True
    ):
        """Create the ViT image encoder and a trainable fp32-safe LayerNorm.

        Sets ``self.vit_name`` and prints the encoder's parameter count.

        Args:
            model_name: "eva_clip_g" or "clip_L" ("eva2_clip_L" is accepted by
                the check but has no builder here).
            freeze_vit: if True, the encoder is frozen and its ``train`` is
                replaced by ``disabled_train`` so it stays in eval mode.

        Returns:
            ``(visual_encoder, ln_vision)``.

        Raises:
            AssertionError: for an unknown ``model_name``.
            NotImplementedError: for "eva2_clip_L" (previously an
                ``UnboundLocalError``).
        """
        assert model_name in [
            "eva_clip_g",
            "eva2_clip_L",
            "clip_L",
        ], "vit model must be eva_clip_g, eva2_clip_L or clip_L"
        if model_name == "eva_clip_g":
            visual_encoder = create_eva_vit_g(
                img_size, drop_path_rate, use_grad_checkpoint, precision
            )
        elif model_name == "clip_L":
            from minigpt4.models.clip_vit import create_clip_vit_L
            visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
        else:
            raise NotImplementedError(
                f"vision encoder {model_name!r} has no builder in this module"
            )
        ln_vision = LayerNorm(visual_encoder.num_features)
        self.vit_name = model_name
        pytorch_total_params = sum(p.numel() for p in visual_encoder.parameters())
        print(f'{model_name} clip vit params:')
        print(f"{pytorch_total_params * 1e-9:.2} B")
        if freeze_vit:
            for param in visual_encoder.parameters():
                param.requires_grad = False
            visual_encoder = visual_encoder.eval()
            visual_encoder.train = disabled_train
            # ln_vision is deliberately left trainable (see log message)
            logging.info("freeze vision encoder but not ln_vision")
        return visual_encoder, ln_vision

    # ------------------------------------------------------------------
    # Checkpoints and optimisation
    # ------------------------------------------------------------------
    def mean_pool_adjust_query_tokens(self, state_dict, num_query_token):
        """Mean-pool ``state_dict['query_tokens']`` down to ``num_query_token`` tokens.

        Consecutive groups of ``total // num_query_token`` tokens are averaged.
        The stored token count and hidden size are read from the tensor rather
        than assumed to be 32 and 768.

        Raises:
            ValueError: if the stored token count is not a positive multiple of
                ``num_query_token``.
        """
        query_tokens = state_dict['query_tokens']
        batch, total_tokens, hidden_size = query_tokens.shape
        if num_query_token <= 0 or total_tokens % num_query_token:
            raise ValueError(
                f"cannot pool {total_tokens} query tokens into {num_query_token}"
            )
        group = total_tokens // num_query_token
        grouped = query_tokens.view(batch, num_query_token, group, hidden_size)
        state_dict['query_tokens'] = torch.mean(grouped, dim=2)
        return state_dict

    def load_from_pretrained(self, url_or_filename, num_query_token=32):
        """Load ``checkpoint["model"]`` from a URL or local file (non-strict).

        ``num_query_token`` is currently unused; ``mean_pool_adjust_query_tokens``
        can be applied to the state dict when the query count differs.

        Returns:
            The result of ``self.load_state_dict(..., strict=False)``.

        Raises:
            RuntimeError: if the argument is neither a URL nor an existing file.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            # torch.load unpickles the file: only load trusted checkpoints
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")
        state_dict = checkpoint["model"]
        msg = self.load_state_dict(state_dict, strict=False)
        logging.info("load checkpoint from %s", url_or_filename)
        return msg

    def get_optimizer_params(self, weight_decay, lr_scale=1):
        """Group trainable parameters for the optimizer.

        1-D parameters and biases get no weight decay. Vision-encoder parameters
        are additionally grouped by layer id, with
        ``lr_scale ** (num_layers + 1 - layer_id)`` as their learning-rate
        scale; all other groups use a scale of 1.

        Returns:
            A list of dicts with ``weight_decay``, ``params`` and ``lr_scale``.
        """
        vit_num_layers = self.visual_encoder.get_num_layer()
        lr_scales = [lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2)]
        parameter_group_names = {}  # names only, kept for debugging
        parameter_group_vars = {}
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias"):
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay
            if 'visual_encoder' in name:
                layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
                group_name = "vit_layer_%d_%s" % (layer_id, group_name)
            else:
                layer_id = None
            if group_name not in parameter_group_vars:
                scale = lr_scales[layer_id] if layer_id is not None else 1
                for groups in (parameter_group_names, parameter_group_vars):
                    groups[group_name] = {
                        "weight_decay": this_weight_decay,
                        "params": [],
                        "lr_scale": scale,
                    }
            parameter_group_vars[group_name]["params"].append(param)
            parameter_group_names[group_name]["params"].append(name)
        return list(parameter_group_vars.values())

    # ------------------------------------------------------------------
    # Answer lemmatization
    # ------------------------------------------------------------------
    def _lemmatize(self, answers):
        """Replace NOUN/VERB tokens of each answer by their spaCy lemma."""
        def apply(answer):
            doc = self.lemmatizer(answer)
            words = [
                token.lemma_ if token.pos_ in ["NOUN", "VERB"] else token.text
                for token in doc
            ]
            return " ".join(words)

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        """Lazily load and cache the spaCy ``en_core_web_sm`` pipeline.

        Logs install instructions and exits with status 1 if spaCy is not
        installed or the model cannot be loaded.
        """
        # ``_lemmatizer`` may never have been assigned; treat that as "not loaded"
        if getattr(self, "_lemmatizer", None) is None:
            try:
                import spacy
                self._lemmatizer = spacy.load("en_core_web_sm")
            except (ImportError, OSError):
                # OSError is what spacy.load raises when the model is missing
                logging.error(
                    """
Please install spacy and en_core_web_sm model to apply lemmatization.
python -m spacy download en_core_web_sm
OR
import spacy.cli
spacy.cli.download("en_core_web_sm")
"""
                )
                raise SystemExit(1)
        return self._lemmatizer
def disabled_train(self, mode=True):
    """No-op stand-in for ``nn.Module.train``.

    Assign it over a module's ``train`` attribute so that later
    ``train(mode)`` calls leave the module's mode unchanged. ``mode`` is
    accepted only for signature compatibility and is ignored.
    """
    del mode  # intentionally unused
    return self
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalises a float32 copy of its input.

    The result is cast back to the input's dtype, so the layer can sit in an
    fp16 pipeline without normalising in half precision.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
def _shard_bounds(num_rows, num_tasks, rank):
    """Return the ``(start, end)`` row slice of ``num_rows`` handled by ``rank``."""
    # ``+ 1`` rounds each rank's share up so every row is covered; ``end`` is
    # clamped so the last rank never slices past the matrix.
    step = num_rows // num_tasks + 1
    start = rank * step
    return start, min(num_rows, start + step)


@torch.no_grad()
def compute_sim_matrix(model, data_loader, **kwargs):
    """Compute image-to-text and text-to-image retrieval score matrices.

    Texts (``data_loader.dataset.text``) are tokenized in batches of 256
    (max_length 35) and embedded with ``model.forward_text``/``text_proj``.
    Images are embedded with ``model.forward_image``/``vision_proj``, and their
    ViT features are kept on the CPU. The coarse image-text similarity is the
    max over image query embeddings of ``image_embed @ text_embeds.T``.

    Each rank then re-ranks its shard of rows. The top ``k_test`` candidates are
    scored with ``model.compute_itm``, and the final score is the ITM score plus
    the coarse similarity; all other entries keep -100. If torch.distributed is
    initialised, the per-rank matrices are summed with ``all_reduce``.

    Runs under ``torch.no_grad()`` so the returned matrices can be converted to
    NumPy even when the caller did not disable autograd.

    Args:
        model: retrieval model exposing the methods used above.
        data_loader: loader whose dataset has ``text`` and ``image`` sequences.
        **kwargs: must contain ``k_test`` (number of candidates to re-rank);
            it is clamped to the number of available candidates.

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` as NumPy arrays of shape
        ``(num_images, num_texts)`` and ``(num_texts, num_images)``.

    Raises:
        KeyError: if ``k_test`` is missing from ``kwargs``.
    """
    k_test = kwargs.pop("k_test")

    metric_logger = MetricLogger(delimiter=" ")
    header = "Evaluation:"

    logging.info("Computing features for evaluation...")
    start_time = time.time()

    # ---- text features -------------------------------------------------
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i : min(num_text, i + text_bs)]
        text_input = model.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(model.device)
        text_feat = model.forward_text(text_input)
        # text embeddings are 2-D (``.t()`` below requires it), so dim=-1 == dim=1
        text_embeds.append(F.normalize(model.text_proj(text_feat), dim=-1))
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    # ---- image features ------------------------------------------------
    vit_feats = []
    image_embeds = []
    for samples in data_loader:
        image = samples["image"].to(model.device)
        image_feat, vit_feat = model.forward_image(image)
        image_embed = F.normalize(model.vision_proj(image_feat), dim=-1)
        vit_feats.append(vit_feat.cpu())  # kept on CPU to save device memory
        image_embeds.append(image_embed)

    vit_feats = torch.cat(vit_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    # ---- coarse similarity: best-matching query token per image --------
    sims_matrix = torch.stack(
        [(image_embed @ text_embeds.t()).max(0)[0] for image_embed in image_embeds],
        dim=0,
    )

    num_tasks = dist_utils.get_world_size()
    rank = dist_utils.get_rank()

    # ---- image -> text re-ranking --------------------------------------
    score_matrix_i2t = torch.full(
        (len(data_loader.dataset.image), len(texts)), -100.0
    ).to(model.device)
    start, end = _shard_bounds(sims_matrix.size(0), num_tasks, rank)
    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        k = min(k_test, sims.size(0))  # topk fails if k exceeds candidates
        topk_sim, topk_idx = sims.topk(k=k, dim=0)
        image_inputs = vit_feats[start + i].repeat(k, 1, 1).to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[topk_idx],
            text_atts=text_atts[topk_idx],
        ).float()
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    # ---- text -> image re-ranking --------------------------------------
    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full(
        (len(texts), len(data_loader.dataset.image)), -100.0
    ).to(model.device)
    start, end = _shard_bounds(sims_matrix.size(0), num_tasks, rank)
    for i, sims in enumerate(
        metric_logger.log_every(sims_matrix[start:end], 50, header)
    ):
        k = min(k_test, sims.size(0))
        topk_sim, topk_idx = sims.topk(k=k, dim=0)
        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
        score = model.compute_itm(
            image_inputs=image_inputs,
            text_ids=text_ids[start + i].repeat(k, 1),
            text_atts=text_atts[start + i].repeat(k, 1),
        ).float()
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    # ---- combine shards across ranks -----------------------------------
    if dist_utils.is_dist_avail_and_initialized():
        dist.barrier()
        dist.all_reduce(score_matrix_i2t, op=dist.ReduceOp.SUM)
        dist.all_reduce(score_matrix_t2i, op=dist.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logging.info("Evaluation time {}".format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()