# MiniGPT-4/minigpt4/models/blip2_vicuna_instruct.py
"""
Requires transformers>=4.28; the implementation may change to track the upstream Llama implementation.
"""
import json
import logging
import os
import re
import string

import numpy as np
from packaging import version
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
import transformers
from minigpt4.common.registry import registry
from minigpt4.models.blip2 import Blip2Base, disabled_train
@registry.register_model("blip2_vicuna_instruct")
class Blip2VicunaInstruct(Blip2Base):
"""
    BLIP-2 Vicuna Instruct model.
    Supported model types:
        - vicuna7b_instruct
        - vicuna7b_pretrain
        - vicuna7b_qfmoe_post
        - vicuna7b_qfmoe_route
Usage:
>>> from minigpt4.models import load_model
>>> import torch
>>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
>>> model = load_model("blip2_vicuna_instruct", "vicuna7b_qfmoe_route", device=device)
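    >>> # A minimal inference sketch; the sample keys follow generate() below,
    >>> # and the image tensor is a hypothetical stand-in for a preprocessed batch.
    >>> samples = {
    ...     "image": torch.randn(1, 3, 224, 224).to(device),
    ...     "q_input": ["What is shown in the image?"],
    ...     "llm_input": ["What is shown in the image?"],
    ... }
    >>> outputs = model.generate(samples)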
"""
PRETRAINED_MODEL_CONFIG_DICT = {
"vicuna7b_instruct": "configs/models/blip2/blip2_instruct_vicuna7b.yaml",
"vicuna7b_pretrain": "configs/models/blip2/blip2_pretrain_vicuna7b.yaml",
"vicuna7b_qfmoe_post": "configs/models/blip2/blip2_qformer_moe_post_vicuna7b.yaml",
"vicuna7b_qfmoe_route": "configs/models/blip2/blip2_pretrain_vicuna7b_route_moe.yaml",
}
def __init__(
self,
vit_model="eva_clip_g",
q_former_model="",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
freeze_llm=True,
freeze_qformer=False,
freeze_proj=False,
num_query_token=32,
llm_model="",
prompt="",
max_txt_len=128,
max_output_txt_len=256,
apply_lemmatizer=False,
qformer_text_input=True,
        use_moeqformer=False,
        use_route_moe=False,
        moebert_num_beams=2,
        moebert_expert_num=5,
        moebert_route_method="gate-sentence",
        moebert_load_balance=0.1,
        moe_topk=1,
        use_balance_loss=True,
        moe_weight_type="l2_norm",
        gate_save_path=None,
        bal_loss_decay_epoch=3,
        ln_position="out",
):
super().__init__()
transformers_version = version.parse(transformers.__version__)
assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
from transformers import LlamaTokenizer
from minigpt4.models.modeling_llama import LlamaForCausalLM
self.tokenizer = self.init_tokenizer(truncation_side="left")
        print('Initializing and loading ViT')
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
)
if freeze_vit:
for name, param in self.visual_encoder.named_parameters():
param.requires_grad = False
self.visual_encoder = self.visual_encoder.eval()
self.visual_encoder.train = disabled_train
# freeze ln vision
# for name, param in self.ln_vision.named_parameters():
# param.requires_grad = False
# self.ln_vision = self.ln_vision.eval()
# self.ln_vision.train = disabled_train
logging.info("freeze vision encoder but not ln_vision")
        print('Loading ViT done')
if use_moeqformer:
if use_route_moe:
self.Qformer, self.query_tokens = self.init_RouteMoEQformer(
num_query_token=num_query_token,
vision_width=self.visual_encoder.num_features,
moebert_expert_num=moebert_expert_num,
moebert_num_beams=moebert_num_beams,
route_method=moebert_route_method,
moe_weight_type=moe_weight_type,
cross_attention_freq=2
)
else:
self.Qformer, self.query_tokens = self.init_QformerMoE(
num_query_token=num_query_token,
vision_width=self.visual_encoder.num_features,
moebert_expert_num=moebert_expert_num,
moebert_route_method=moebert_route_method,
moebert_load_balance=moebert_load_balance,
moe_topk=moe_topk,
use_balance_loss=use_balance_loss,
moe_weight_type=moe_weight_type,
cross_attention_freq=2,
ln_position=ln_position,
)
else:
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features
)
if not qformer_text_input:
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
else:
self.Qformer.resize_token_embeddings(len(self.tokenizer))
self.Qformer.cls = None
print("Loading LLM")
self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
self.llm_model = LlamaForCausalLM.from_pretrained(
llm_model, torch_dtype=torch.float16
)
# self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
# self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
# self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
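        # LlamaTokenizer ships without a pad token; reusing unk as pad is a
        # common workaround so batch padding gets a valid token id.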
self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
# self.eos_token_id = self.llm_tokenizer(
# self.llm_tokenizer.eos_token, add_special_tokens=False
# ).input_ids[0]
if freeze_llm:
for name, param in self.llm_model.named_parameters():
param.requires_grad = False
self.llm_proj = nn.Linear(
self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
)
if qformer_text_input:
            # Hard-coded to load the BLIP-2 stage-1 pre-trained weights (to initialize the FFN; not ideal)
self.load_from_pretrained(
url_or_filename="/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_pretrained/blip2_pretrained.pth",
num_query_token=num_query_token
)
if use_moeqformer:
            # load the BLIP-2 Vicuna pre-trained checkpoint to initialize the query FFN
self.load_from_pretrained(
url_or_filename=q_former_model,
num_query_token=num_query_token
)
            # Init MoE layers: initialize each expert FFN from the BLIP-2 query FFN weights
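            # For illustration, a hypothetical mapping for layer 6, expert 0
            # (key names follow the regex patterns below):
            #   experts.experts.0.dense1.weight  <-  intermediate_query.dense.weight
            #   experts.experts.0.dense2.weight  <-  output_query.dense.weight
            #   expert_ln.weight                 <-  output_query.LayerNorm.weight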
state_dict = self.Qformer.state_dict()
for name, param in self.Qformer.named_parameters():
if "_query" in name and "experts.experts" in name:
pattern = r'\.experts\.experts\.\d+'
key_orig = re.sub(pattern, '', name)
param.data.copy_(state_dict[key_orig]) # copy state_dict[key_orig] to param
if "experts.intermediate_query" in name or "experts.output_query" in name:
key_orig = re.sub(r'experts\.', '', name)
param.data.copy_(state_dict[key_orig]) # copy state_dict[key_orig] to param
if "_query" in name and "experts" not in name: # raw ffn_query not update
param.requires_grad = False
ln_pattern = r"bert\.encoder\.layer\.\d+\.expert_ln\.(weight|bias)"
if re.match(ln_pattern, name):
key_orig = re.sub('expert_ln', 'output_query.LayerNorm', name)
param.data.copy_(state_dict[key_orig])
d1_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense1\.(weight|bias)"
if re.match(d1_pattern, name):
key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense1', 'intermediate_query.dense', name)
param.data.copy_(state_dict[key_orig])
d2_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense2\.(weight|bias)"
if re.match(d2_pattern, name):
key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense2', 'output_query.dense', name)
param.data.copy_(state_dict[key_orig])
# freeze qformer
if freeze_qformer:
for name, param in self.Qformer.named_parameters():
param.requires_grad = False
self.Qformer = self.Qformer.eval()
self.Qformer.train = disabled_train
logging.info("freeze Qformer")
# After loading, freeze llm_proj
if freeze_proj:
for name, param in self.llm_proj.named_parameters():
param.requires_grad = False
self.llm_proj = self.llm_proj.eval()
self.llm_proj.train = disabled_train
self.max_txt_len = max_txt_len
self.max_output_txt_len = max_output_txt_len
self.prompt = prompt
prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
self.prompt_length = prompt_tokens.attention_mask.sum(1)
self._lemmatizer = None
self.qformer_text_input = qformer_text_input
self.use_moeqformer = use_moeqformer
self.use_route_moe = use_route_moe
self.moebert_load_balance = moebert_load_balance
self.moebert_num_beams = moebert_num_beams
self.gate_save_path = gate_save_path
self.bal_loss_decay_epoch = bal_loss_decay_epoch
# if self.gate_save_path != None:
# import os
# if not os.path.exists(self.gate_save_path):
# os.mkdir(self.gate_save_path)
def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
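        """Splice each sample's output tokens right after its non-padded input tokens.

        A minimal sketch (hypothetical ids; 0 = pad, input has 3 real tokens):
            input_ids  = [p, q1, q2, 0, 0],  output_ids = [bos, a1, a2]
            result     = [p, q1, q2, a1, a2, 0, 0]
        The output's leading (BOS) token is dropped, and input padding moves to
        the tail; `input_part_targets_len` records each input's real length.
        """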
input_part_targets_len = []
llm_tokens = {"input_ids": [], "attention_mask": []}
for i in range(input_ids.size(0)):
this_input_ones = input_atts[i].sum()
input_part_targets_len.append(this_input_ones)
llm_tokens['input_ids'].append(
torch.cat([
input_ids[i][:this_input_ones],
output_ids[i][1:],
input_ids[i][this_input_ones:]
])
)
llm_tokens['attention_mask'].append(
torch.cat([
input_atts[i][:this_input_ones],
output_atts[i][1:],
input_atts[i][this_input_ones:]
])
)
llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
return llm_tokens, input_part_targets_len
def forward(self, samples):
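        """Run a training step: Q-Former visual features + instruction -> LLM loss.

        `samples` is expected to carry the keys used below: "image", "q_input",
        "llm_input", "text_output", and "epoch" (plus "image_id"/"iters" when
        gate routes are saved).
        """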
image = samples["image"]
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
bs = image.size(0)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
if self.qformer_text_input:
text_Qformer = self.tokenizer(
samples["q_input"],
padding='longest',
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(image.device)
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1)
query_output = self.Qformer.bert(
text_Qformer.input_ids,
attention_mask=Qformer_atts,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
output_hidden_states=True,
)
else:
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
output_hidden_states=True,
)
query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:]
if self.use_moeqformer:
gate_loss = query_output.gate_loss # only available in QformerMoE
            if self.gate_save_path is not None:
all_hidden_states = query_output.hidden_states
# prob_gate_normalized = query_output.gate_loads
beam_scores = query_output.beam_scores
expert_route = query_output.expert_route
                gate_route = []
try:
for i in range(len(samples['image_id'])):
image_id = samples['image_id'][i]
gate_route.append({
'iters': samples['iters'],
'image_id':image_id,
'q_input': samples['q_input'][i],
'text_output': samples['text_output'][i],
'beam_scores': beam_scores[i].tolist(),
'expert_route': expert_route[i].tolist(),
# 'gate_route_11': prob_gate_normalized[10][i].tolist(),
# 'gate_route_9': prob_gate_normalized[8][i].tolist(),
# 'gate_route_7': prob_gate_normalized[6][i].tolist(),
# 'gate_route_5': prob_gate_normalized[4][i].tolist(),
# 'gate_route_3': prob_gate_normalized[2][i].tolist(),
# 'gate_route_1': prob_gate_normalized[0][i].tolist(),
})
                        # for layer in [6, 8, 10]:
                        #     layer_data = all_hidden_states[layer]
                        #     file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
                        #     x = layer_data.data.cpu().numpy()
                        #     np.save(file_path, x)
with open(os.path.join(self.gate_save_path, 'train_save_beam.json'),'a+') as f:
f.write(f"{json.dumps(gate_route)}\n")
except Exception as e:
print("Gate Save Error....")
print(e)
inputs_llm = self.llm_proj(query_output_to_linear)
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
self.llm_tokenizer.padding_side = "right"
self.llm_tokenizer.truncation_side = 'left'
text_input_tokens = self.llm_tokenizer(
samples['llm_input'],
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_txt_len,
).to(image.device)
self.llm_tokenizer.truncation_side = 'right'
text_output_tokens = self.llm_tokenizer(
[t + self.llm_tokenizer.eos_token for t in samples['text_output']],
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_output_txt_len,
).to(image.device)
llm_tokens, input_part_targets_len = self.concat_text_input_output(
text_input_tokens.input_ids,
text_input_tokens.attention_mask,
text_output_tokens.input_ids,
text_output_tokens.attention_mask,
)
# do not apply loss to the padding
targets = llm_tokens['input_ids'].masked_fill(
llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
)
# do not apply loss to the text input (i.e., instruction)
for i, l in enumerate(input_part_targets_len):
targets[i][:l] = -100
# do not apply loss to the query tokens
empty_targets = (
torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
)
targets = torch.cat([empty_targets, targets], dim=1)
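        # Resulting label row per sample (a sketch; lengths are illustrative):
        #   [-100] * n_query_tokens + [-100] * n_input_tokens + output_ids (pad -> -100)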
inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
with self.maybe_autocast():
outputs = self.llm_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=targets,
)
if self.use_moeqformer:
if samples['epoch'] > self.bal_loss_decay_epoch:
loss = outputs.loss
else:
loss = outputs.loss + self.moebert_load_balance * gate_loss
else:
loss = outputs.loss
return {"loss": loss}
@torch.no_grad()
def generate(
self,
samples,
use_nucleus_sampling=False,
num_beams=5,
max_length=256,
min_length=1,
top_p=0.9,
repetition_penalty=1.5,
length_penalty=1,
num_captions=1,
temperature=1,
):
self.llm_tokenizer.padding_side = "left"
image = samples["image"]
bs = image.size(0)
query_tokens = self.query_tokens.expand(bs, -1, -1)
if self.qformer_text_input:
text_Qformer = self.tokenizer(
samples["q_input"],
padding='longest',
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(image.device)
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
if self.qformer_text_input:
query_output = self.Qformer.bert(
text_Qformer.input_ids,
attention_mask=Qformer_atts,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
output_hidden_states=True,
)
else:
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
output_hidden_states=True,
)
            if self.gate_save_path is not None:
all_hidden_states = query_output.hidden_states
# prob_gate_normalized = query_output.gate_loads
beam_scores = query_output.beam_scores
expert_route = query_output.expert_route
                gate_route = []
try:
for i in range(len(samples['image_id'])):
source = samples['source'][i]
if source in ['gqa']:
image_id = samples['image_id'][i].split('.')[0]
else:
image_id = samples['image_id'][i].split('/')[-1].split('.')[0]
gate_route.append({
'source': source,
'image_id':image_id,
'q_input': samples['q_input'][i],
'beam_scores': beam_scores[i].tolist(),
'expert_route': expert_route[i].tolist(),
# 'gate_route_11': prob_gate_normalized[10][i].tolist(),
# 'gate_route_9': prob_gate_normalized[8][i].tolist(),
# 'gate_route_7': prob_gate_normalized[6][i].tolist(),
# 'gate_route_5': prob_gate_normalized[4][i].tolist(),
# 'gate_route_3': prob_gate_normalized[2][i].tolist(),
# 'gate_route_1': prob_gate_normalized[0][i].tolist(),
})
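                        # Layers 6 and 11 are indexed per sample, while layers 7-10
                        # appear to be expanded beam-wise (batch * moebert_num_beams),
                        # hence the two indexing schemes below.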
                        for layer in [6, 7, 8, 9, 10, 11]:
                            if layer in [6, 11]:
                                layer_data = all_hidden_states[layer][i, :, :]
                            else:
                                layer_data = all_hidden_states[layer][i * self.moebert_num_beams, :, :]
                            file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
                            x = layer_data.data.cpu().numpy()
                            np.save(file_path, x)  # all done
with open(os.path.join(self.gate_save_path, 'generate_save_beam.json'),'a+') as f:
f.write(f"{json.dumps(gate_route)}\n")
except Exception as e:
print("Gate Save Error....")
print(e)
inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
llm_tokens = self.llm_tokenizer(
samples['llm_input'],
padding="longest",
return_tensors="pt"
).to(image.device)
with self.maybe_autocast():
inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
outputs = self.llm_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
do_sample=use_nucleus_sampling,
top_p=top_p,
temperature=temperature,
num_beams=num_beams,
max_length=max_length,
min_length=min_length,
# eos_token_id=self.eos_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
num_return_sequences=num_captions,
)
# outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
output_text = [text.strip() for text in output_text]
return output_text
def predict_answers(
self,
samples,
num_beams=5,
inference_method="generate",
max_len=10,
min_len=1,
num_ans_candidates=128,
answer_list=None,
prompt="",
length_penalty=0,
**kwargs
):
if isinstance(samples["llm_input"], str):
samples["llm_input"] = [samples["llm_input"]]
# if prompt:
# if prompt.count("{}") == 2:
# if 'ocr_tokens' in samples:
# text_input = [
# prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["llm_input"][i])
# for i in range(len(samples["llm_input"]))]
# elif 'choices' in samples:
# text_input = []
# for i in range(len(samples["llm_input"])):
# this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
# this_choices = " ".join(this_choices)
# text_input.append(prompt.format(samples["llm_input"][i], this_choices))
# else:
# text_input = [prompt.format(question) for question in samples["llm_input"]]
# else:
# text_input = samples["llm_input"]
# samples["prompt"] = text_input
output_text = self.generate(
samples,
num_beams=num_beams,
max_length=max_len,
min_length=min_len,
length_penalty=length_penalty
)
if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
output_text = self._lemmatize(output_text)
return output_text
def _lemmatize(self, answers):
def apply(answer):
doc = self.lemmatizer(answer)
words = []
for token in doc:
if token.pos_ in ["NOUN", "VERB"]:
words.append(token.lemma_)
else:
words.append(token.text)
answer = " ".join(words)
return answer
return [apply(answer) for answer in answers]
@property
def lemmatizer(self):
if self._lemmatizer is None:
try:
import spacy
self._lemmatizer = spacy.load("en_core_web_sm")
except ImportError:
logging.error(
"""
Please install spacy and en_core_web_sm model to apply lemmatization.
python -m spacy download en_core_web_sm
OR
import spacy.cli
spacy.cli.download("en_core_web_sm")
"""
)
exit(1)
return self._lemmatizer
@classmethod
def from_config(cls, cfg):
vit_model = cfg.get("vit_model", "eva_clip_g")
q_former_model = cfg.get("q_former_model", "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth")
img_size = cfg.get("image_size")
num_query_token = cfg.get("num_query_token")
llm_model = cfg.get("llm_model")
drop_path_rate = cfg.get("drop_path_rate", 0)
use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
freeze_llm = cfg.get("freeze_llm", True)
freeze_qformer = cfg.get("freeze_qformer", False)
freeze_proj = cfg.get("freeze_proj", False)
prompt = cfg.get("prompt", "")
max_txt_len = cfg.get("max_txt_len", 128)
max_output_txt_len = cfg.get("max_output_txt_len", 256)
apply_lemmatizer = cfg.get("apply_lemmatizer", False)
qformer_text_input = cfg.get("qformer_text_input", True)
use_moeqformer = cfg.get("use_moeqformer", False)
use_route_moe = cfg.get("use_route_moe", False)
moebert_num_beams = cfg.get("moebert_num_beams", 2)
moebert_expert_num = cfg.get("moebert_expert_num", 5)
moebert_route_method = cfg.get("moebert_route_method", "gate-sentence")
moebert_load_balance = cfg.get("moebert_load_balance", 0.1)
moe_topk = cfg.get("moe_topk", 1)
        use_balance_loss = cfg.get("use_balance_loss", True)
        moe_weight_type = cfg.get("moe_weight_type", "l2_norm")
        gate_save_path = cfg.get("gate_save_path", None)
        bal_loss_decay_epoch = cfg.get("bal_loss_decay_epoch", 3)
        ln_position = cfg.get("ln_position", "out")
model = cls(
vit_model=vit_model,
img_size=img_size,
q_former_model=q_former_model,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
freeze_llm=freeze_llm,
freeze_qformer=freeze_qformer,
freeze_proj=freeze_proj,
num_query_token=num_query_token,
llm_model=llm_model,
prompt=prompt,
max_txt_len=max_txt_len,
max_output_txt_len=max_output_txt_len,
apply_lemmatizer=apply_lemmatizer,
qformer_text_input=qformer_text_input,
use_moeqformer=use_moeqformer,
use_route_moe=use_route_moe,
moebert_num_beams=moebert_num_beams,
moebert_expert_num=moebert_expert_num,
moebert_route_method=moebert_route_method,
moebert_load_balance=moebert_load_balance,
moe_topk=moe_topk,
use_balance_loss=use_balance_loss,
moe_weight_type=moe_weight_type,
gate_save_path=gate_save_path,
bal_loss_decay_epoch=bal_loss_decay_epoch,
ln_position=ln_position,
)
# if qformer_text_input:
# # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
# model.load_from_pretrained(
# url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
# )
model.load_checkpoint_from_config(cfg)
        # check which parameters will be updated
        print("Updating the following parameters:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name)
return model