# MiniGPT-4/test/models/test_moe_model.py
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
from minigpt4.models.QformerMoE import (
    BertConfig,
    BertMoELMHeadModel,
)
vision_width = 1408
cross_attention_freq = 2
num_query_token = 32
# init_QformerMoE
moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
moe_encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
moe_encoder_config.add_cross_attention = True
moe_encoder_config.cross_attention_freq = cross_attention_freq
moe_encoder_config.query_length = num_query_token
moe_encoder_config.moebert_expert_num = 4
moe_encoder_config.moebert_route_method = "gate-sentence"
moe_encoder_config.moe_topk = 2
moe_encoder_config.moebert_load_balance = 0.1
moe_encoder_config.moebert_share_importance = 512 # TODO: meaning?
MoEQformer = BertMoELMHeadModel.from_pretrained(
    "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
)
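# Optional sanity check (a sketch, not part of the original test): after loading,
# the MoE FFN experts should show up in the parameter names. The "experts" naming
# is assumed from the counting section further below ("11.experts.experts").
moe_param_names = [n for n, _ in MoEQformer.named_parameters() if "experts" in n]
print(len(moe_param_names), moe_param_names[:4])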
"""
Compare Qformer & QformerMoE
"""
# blip2_qformer
# calculate parameters
from minigpt4.models import load_model
model = load_model("blip2", "pretrain")
model.QformerMoE, model.query_tokens_moe = model.init_QformerMoE(
    num_query_token, model.visual_encoder.num_features, cross_attention_freq
)
model.Qformer, model.query_tokens = model.init_Qformer(
    num_query_token, model.visual_encoder.num_features, cross_attention_freq
)
state_dict = model.Qformer.state_dict()
for name, param in model.Qformer.named_parameters():
    # copy the shared text-branch weights into their "_query" counterparts
    if "_query" in name:
        key_orig = name.replace("_query", "")
        param.data.copy_(state_dict[key_orig])
    # debug: inspect parameter names of layer 10
    if "10" in name:
        print(name)
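# Optional check (sketch): after the copy above, every "_query" weight should be
# identical to its non-query counterpart, i.e. the query branch starts from the
# same initialization as the text branch.
for name, param in model.Qformer.named_parameters():
    if "_query" in name:
        assert param.data.equal(state_dict[name.replace("_query", "")]), name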
"""
blip2_t5_qformer_moe
Calculate Num Parameters
"""
import torch
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
from minigpt4.models import model_zoo
from minigpt4.models import load_model
print(model_zoo)
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model = load_model("blip2_t5_qformer_moe", "flant5xxl", device=device)
num_parameters = 0
for n, p in model.Qformer.named_parameters():
    if not p.requires_grad:
        continue  # skip frozen weights
    if "11.experts.experts" in n:
        print(n)
        num_parameters += p.data.nelement()
print(num_parameters)  # parameters of the layer-11 experts: 23,619,840
# total trainable parameters of the model: 415,631,104
num_parameters = 0
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue  # skip frozen weights
    num_parameters += p.data.nelement()
print(num_parameters)  # total trainable parameters: 415,631,104
num_parameters = 0
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue  # skip frozen weights
    if 'Qformer.bert.encoder.layer.6.crossattention' in n:
        num_parameters += p.data.nelement()
    # if 'Qformer.bert.encoder.layer.11.output' in n:
    #     num_parameters += p.data.nelement()
print(num_parameters)
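# The three loops above share the same pattern; a small helper (a sketch, not part
# of the repo's API) makes the per-module counts easier to repeat. `name_filter`
# is a hypothetical argument name.
def count_trainable(module, name_filter=""):
    return sum(p.nelement() for n, p in module.named_parameters()
               if p.requires_grad and name_filter in n)
# Example usage, mirroring the counts above:
# count_trainable(model.Qformer, "11.experts.experts")
# count_trainable(model)
# count_trainable(model, "Qformer.bert.encoder.layer.6.crossattention")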
"""
forward
"""
import torch
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
from minigpt4.models import load_model
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model = load_model("blip2", "pretrain", device=device)
samples = {
    'q_input': ["What is around the open window?",  # n23181
                "Is the ground blue or brown?",  # n168412
                "What color are the pants?",  # n446242
                "What is the airplane flying above?"],  # n414992
    'llm_input': ["What is around the open window?",  # n23181
                  "Is the ground blue or brown?",  # n168412
                  "What color are the pants?",  # n446242
                  "What is the airplane flying above?"],  # n414992
    'text_output': ["drapes",
                    "brown",
                    "red",
                    "ocean"],
    'image': torch.randn(4, 3, 224, 224).half().to(device),
    # 'image': torch.randn(4, 3, 336, 336).half().to(device),
}
Qformer, query_tokens = model.init_QformerMoE(
    num_query_token=32,
    vision_width=1408,
    moebert_expert_num=5,
    moebert_route_method="gate-sentence",
    moebert_load_balance=0.1,
    moe_topk=2,
    cross_attention_freq=2,
)
Qformer = Qformer.to(device)
import contextlib

def maybe_autocast(device, dtype=torch.float16):
    # on cpu, don't use autocast
    # on gpu, use autocast with the given dtype (defaults to torch.float16)
    enable_autocast = device != torch.device("cpu")
    if enable_autocast:
        return torch.cuda.amp.autocast(dtype=dtype)
    else:
        return contextlib.nullcontext()
image = samples["image"]
with maybe_autocast(device):
image_embeds = model.ln_vision(model.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
bs = image.size(0)
query_tokens = query_tokens.expand(bs, -1, -1).to(device)
# image = samples["image"]
# image_atts = torch.ones(4, 257).to(device)
# image_embeds = torch.randn(4, 257, 1408).to(device)
# bz = image_embeds.shape[0]
# query_tokens = query_tokens.expand(bz, -1, -1).to(device)
text_Qformer = model.tokenizer(
    samples["q_input"],
    padding='longest',
    truncation=True,
    max_length=32,
    return_tensors="pt",
).to(image.device)
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1).to(device)
query_output = Qformer.bert(
    text_Qformer.input_ids,
    attention_mask=Qformer_atts,
    query_embeds=query_tokens,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    return_dict=True,
)
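# Quick shape check (sketch), assuming the MoE Q-Former keeps the standard
# BertModel output format: the first num_query_token positions of
# last_hidden_state are the query embeddings, the rest correspond to the text.
print(query_output.last_hidden_state.shape)  # (4, 32 + text_len, hidden_size)
query_embeds = query_output.last_hidden_state[:, :query_tokens.size(1), :]
print(query_embeds.shape)  # (4, 32, hidden_size)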