mirror of https://github.com/Vision-CAIR/MiniGPT-4.git (synced 2025-04-26 15:40:45 +00:00)

Route MoE Promote (Post/Pre) update 0110

This commit is contained in:
parent ce67e5669a
commit eb022668a3

Pre_PromptMoE_RawProb_backward_graph (new file, 5294 lines)
File diff suppressed because it is too large.

Pre_PromptMoE_RawProb_backward_graph.pdf (new binary file)
Binary file not shown.
@@ -1,4 +1,4 @@
-name: minigptv
+name: promptmoe
 channels:
   - pytorch
   - defaults

@@ -17,14 +17,14 @@ datasets:
      # md5: aa31ac474cf6250ebb81d18348a07ed8
      storage:
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_train.json
-    val:
-      url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
-      storage:
-        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
-    test:
-      url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
-      storage:
-        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json
+    # val:
+    #   url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
+    #   storage:
+    #     - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
+    # test:
+    #   url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
+    #   storage:
+    #     - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json

  images:
    storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO
@@ -20,6 +20,7 @@ datasets:
        - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json
      storage:
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val_eval.json
+       # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_mscoco_val2014_annotations.json

@@ -29,6 +30,7 @@ datasets:
        - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json
      storage:
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_test.json
+       # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json

  images:
@@ -20,6 +20,7 @@ datasets:
        - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
      storage:
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
+       # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json

@@ -32,6 +33,7 @@ datasets:
        - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
      storage:
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
        # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
+       # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval_part100.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
        - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json
@@ -105,6 +105,8 @@ class COCOCaptionDataset(BaseDataset, __DisplMixin):
            'Using language, provide a short account of the image.',
            'Use a few words to illustrate what is happening in the picture.',
        ]
+       self.source = 'coco_cap'
+
    def __getitem__(self, index):

        # TODO this assumes image input, not general enough

@@ -118,13 +120,20 @@ class COCOCaptionDataset(BaseDataset, __DisplMixin):
        image = self.vis_processor(image)
        caption = self.text_processor(ann["caption"])

-       instruction = random.choice(self.instruction_pool)
-       instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
+       # instruction = random.choice(self.instruction_pool)
+       # instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
+       q_input = ""
+       llm_input = random.choice(self.instruction_pool)

        return {
            "image": image,
            "image_id": ann["image"],
            "answer": caption,
            "instruction_input": instruction,
+           "q_input": q_input,
+           "llm_input": llm_input,
+           "text_input": llm_input,
+           "text_output": caption,
+           "source": 'coco_cap',
        }


class CaptionEvalDataset(BaseDataset, __DisplMixin):
@@ -31,6 +31,7 @@ class COCOCapEvalDataset(CaptionEvalDataset):
            split (string): val or test
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+       self.source = 'coco_cap'

    def __getitem__(self, index):
        ann = self.annotation[index]
@@ -31,7 +31,6 @@ class MultiIterLoader:
        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
-           # import pdb; pdb.set_trace()
            assert len(ratios) == len(loaders)
            ratios = [float(ratio) / sum(ratios) for ratio in ratios]
@@ -12,7 +12,6 @@ from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
-from datasets import load_dataset
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")

@@ -248,6 +247,7 @@ if 'vsr' in args.dataset:
    img_path = cfg.evaluation_datasets_cfg["vsr"]["img_path"]
    batch_size = cfg.evaluation_datasets_cfg["vsr"]["batch_size"]
    max_new_tokens = cfg.evaluation_datasets_cfg["vsr"]["max_new_tokens"]
+   from datasets import load_dataset

    annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test')
    data = VSREvalData(annotation, vis_processor, img_path)
@@ -386,17 +386,23 @@ class BertOutput(nn.Module):  # Add & Norm


class FeedForward(nn.Module):
+   # remove LayerNorm
    def __init__(self, config):
-       nn.Module.__init__(self)
-       # first layer
-       self.intermediate_query = BertIntermediate(config)
-       # second layer
-       self.output_query = BertOutput(config)
+       super().__init__()
+       self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
+       if isinstance(config.hidden_act, str):
+           self.intermediate_act_fn = ACT2FN[config.hidden_act]
+       else:
+           self.intermediate_act_fn = config.hidden_act
+       self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
+       self.dropout = nn.Dropout(config.hidden_dropout_prob)  # adjust dropout ratio 0.1->0.2
+       # self.dropout = nn.Dropout(0.2)  # adjust dropout ratio 0.1->0.2

    def forward(self, hidden_states: Tensor):
-       input_tensor = hidden_states
-       intermediate_output = self.intermediate_query(hidden_states)
-       hidden_states = self.output_query(intermediate_output, input_tensor)
+       hidden_states = self.dense1(hidden_states)
+       hidden_states = self.intermediate_act_fn(hidden_states)
+       hidden_states = self.dense2(hidden_states)
+       hidden_states = self.dropout(hidden_states)
        return hidden_states
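Context for the hunk above: the commit rewrites FeedForward from a BertIntermediate/BertOutput pair into a bare two-layer MLP, moving the residual add and LayerNorm out of the expert (they are applied once after expert mixing; see the expert_ln hunks below). A minimal runnable sketch of the resulting pattern; the sizes and GELU activation here are illustrative assumptions, not the repository's exact configuration:

import torch
import torch.nn as nn

class ExpertFFN(nn.Module):
    # Plain two-layer MLP, mirroring the rewritten FeedForward:
    # no residual connection and no LayerNorm inside the expert.
    def __init__(self, hidden_size=768, intermediate_size=3072, dropout_prob=0.1):
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, intermediate_size)
        self.act = nn.GELU()
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states):
        return self.dropout(self.dense2(self.act(self.dense1(hidden_states))))

hidden = torch.randn(4, 32, 768)
expert = ExpertFFN()
post_ln = nn.LayerNorm(768)
out = post_ln(expert(hidden) + hidden)  # add & norm applied once, outside the expert
print(out.shape)  # torch.Size([4, 32, 768])

Keeping the experts LayerNorm-free means all experts share a single post-mixing add & norm, so their outputs can be weighted and summed before normalization.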
@@ -440,6 +446,7 @@ class BertLayer(nn.Module):
            )
        else:
            self.experts = ffn
+       self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
@@ -494,6 +501,7 @@ class BertLayer(nn.Module):
            moe_ffn_attention_input = query_attention_output[:, :query_length, :]
            moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
            layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask)  # layer_output, gate_loss, gate_load
+           # import pdb; pdb.set_trace()  # test0107

            if attention_output.shape[1] > query_length:  # have text input in Qformer
                layer_output_text = apply_chunking_to_forward(

@@ -503,6 +511,7 @@ class BertLayer(nn.Module):
                    attention_output[:, query_length:, :],
                )
                layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])
+
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
@@ -524,15 +533,14 @@ class BertLayer(nn.Module):

    def feed_forward_query_moe(self, attention_output, expert_attention_mask):
        if not self.use_experts:
-           layer_output = self.experts(attention_output)
+           hidden_states = self.experts(attention_output)
+           layer_output = self.expert_ln(hidden_states + attention_output)
            return layer_output, 0.0, []

-       # if not self.importance_processor.is_moe:
-       #     raise RuntimeError("Need to turn the model to a MoE first.")
-
-       layer_output, gate_loss, gate_load = self.experts(
+       hidden_states, gate_loss, gate_load = self.experts(
            attention_output, expert_attention_mask
        )
+       layer_output = self.expert_ln(hidden_states + attention_output)
        return layer_output, gate_loss, gate_load

class BertEncoder(nn.Module):
@@ -46,10 +46,9 @@ from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig

from minigpt4.models.moe.utils import (
    FeedForward,
    MoEModelOutput,
    MoEModelOutputWithPooling,
-   use_experts,
    use_experts_route,
    moe_layer_judge,
)
from minigpt4.models.moe.route_moe_layer import RouteMoELayer

@@ -378,13 +377,14 @@ class BertOutput(nn.Module):  # Add & Norm
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-       self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+       self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 1
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
-       hidden_states = self.LayerNorm(hidden_states + input_tensor)
+       # Move LayerNorm & ResNet out of FFN After MoEFFN
+       hidden_states = self.LayerNorm(hidden_states + input_tensor)  # 1
        return hidden_states
@@ -429,7 +429,7 @@ class BertLayer(nn.Module):
        self.output_query = BertOutput(config)

        # Add MoE FFN
-       self.use_experts = use_experts(layer_num)
+       self.use_experts = use_experts_route(layer_num)
        self.layer_judge = moe_layer_judge(layer_num)
        self.num_beams = config.moebert_num_beams
        ffn = FeedForward(config)

@@ -442,10 +442,13 @@ class BertLayer(nn.Module):
                num_beams=config.moebert_num_beams,
                layer_judge = self.layer_judge,
                route_method=config.route_method,
+               weight_type=config.moe_weight_type,
            )
        else:
            self.experts = ffn
+
+       # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
    def forward(
        self,
        hidden_states,
@@ -463,7 +466,7 @@ class BertLayer(nn.Module):
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
-       # import pdb;pdb.set_trace()
+       # import pdb; pdb.set_trace()  # 0107test

        # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same
        if self.num_beams > 1:

@@ -494,10 +497,6 @@ class BertLayer(nn.Module):

        present_key_value = self_attention_outputs[-1]

-       # import pdb;pdb.set_trace()
-       # print(self.layer_num, hidden_states.shape, attention_mask.shape)
-
-
        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]
@@ -526,7 +525,8 @@ class BertLayer(nn.Module):
            moe_ffn_attention_input = query_attention_output[:, :query_length, :]
            moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
            layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
-           # layer_output = (layer_output, beam_scores, expert_route, beam_idx)
+           # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
+           # import pdb; pdb.set_trace()  # 0107test

            if attention_output.shape[1] > query_length:  # have text input in Qformer
                layer_output_text = apply_chunking_to_forward(
@@ -535,7 +535,8 @@ class BertLayer(nn.Module):
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
-               if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
+               if self.layer_judge == 'first' and self.num_beams>1:
+                   # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
                    # adjust the dimension of layer_output_text to bz*num_beams
                    layer_output_text = self.adjust_layer_output_text(layer_output_text)

@@ -550,7 +551,8 @@ class BertLayer(nn.Module):
                    # layer_output & layer_output_text dimen_0 from bz*num_beams to bz
                    layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text)

-           layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])
+           layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3], layer_output[4])
+           # import pdb; pdb.set_trace()  # 0107test

        else:
            layer_output = apply_chunking_to_forward(

@@ -559,7 +561,7 @@ class BertLayer(nn.Module):
                self.seq_len_dim,
                attention_output,
            )
-           layer_output = (layer_output, None, None)
+           layer_output = (layer_output, None, None, None, 0.0)

        outputs = (layer_output,) + outputs
@@ -594,24 +596,27 @@ class BertLayer(nn.Module):
        beam_scores_new = beam_scores[selects]
        expert_route_new = expert_route[selects]

-       return (hidden_states_new, beam_scores_new, expert_route_new), layer_output_text
+       return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
+       # layer_output = self.LayerNorm(layer_output + attention_output)
        return layer_output

    def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
        if not self.use_experts:
            layer_output = self.experts(attention_output)
-           return layer_output, None, None, None
+           # layer_output = self.LayerNorm(layer_output + attention_output)
+           return layer_output, None, None, None, 0.0

-       layer_output, beam_scores, expert_route, beam_idx = self.experts(
+       layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
            attention_output, expert_attention_mask, beam_scores, expert_route
        )
-       return layer_output, beam_scores, expert_route, beam_idx
+       # layer_output = self.LayerNorm(layer_output + attention_output)
+       return layer_output, beam_scores, expert_route, beam_idx, importance_loss

class BertEncoder(nn.Module):
    def __init__(self, config):
@@ -645,6 +650,7 @@ class BertEncoder(nn.Module):
        next_decoder_cache = () if use_cache else None
        beam_scores=None
        expert_route=None
+       importance_loss = 0
        for i in range(self.config.num_hidden_layers):

            layer_module = self.layer[i]

@@ -693,6 +699,7 @@ class BertEncoder(nn.Module):
            hidden_states = layer_outputs[0][0]
            beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1]
            expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2]
+           importance_loss += layer_outputs[0][4]

            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

@@ -724,6 +731,7 @@ class BertEncoder(nn.Module):
            cross_attentions=all_cross_attentions,
            beam_scores=beam_scores,
            expert_route=expert_route,
+           gate_loss=importance_loss,
        )

@@ -1103,6 +1111,7 @@ class BertModel(BertPreTrainedModel):
            cross_attentions=encoder_outputs.cross_attentions,
            beam_scores=encoder_outputs.beam_scores,
            expert_route=encoder_outputs.expert_route,
+           gate_loss=encoder_outputs.gate_loss
        )
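As the hunks above show, each BertLayer now returns a 5-tuple whose last element is the layer's importance loss, and BertEncoder sums these into a single gate_loss. A toy sketch of that accumulation; the per-layer loss values here are made up:

import torch

# Each layer yields (hidden_states, beam_scores, expert_route, beam_idx, importance_loss);
# non-MoE layers contribute 0.0. The encoder keeps a running sum, mirroring
# `importance_loss += layer_outputs[0][4]`.
layer_losses = [torch.tensor(0.10), torch.tensor(0.0), torch.tensor(0.25)]
importance_loss = 0
for loss in layer_losses:
    importance_loss += loss
print(importance_loss)  # tensor(0.3500), reported downstream as gate_loss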
@@ -62,7 +62,7 @@ class Blip2Base(BaseModel):
        return Qformer, query_tokens

    @classmethod
-   def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, cross_attention_freq=2):
+   def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2):
        moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")

        moe_encoder_config.encoder_width = vision_width

@@ -74,6 +74,7 @@ class Blip2Base(BaseModel):
        moe_encoder_config.moebert_expert_num = moebert_expert_num
        moe_encoder_config.moebert_num_beams = moebert_num_beams
        moe_encoder_config.route_method = route_method
+       moe_encoder_config.moe_weight_type = moe_weight_type

        RouteMoEQformer = BertMoERouteLMHeadModel.from_pretrained(
            "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
@@ -99,6 +99,7 @@ class Blip2VicunaInstruct(Blip2Base):
                moebert_expert_num=moebert_expert_num,
                moebert_num_beams=moebert_num_beams,
                route_method=moebert_route_method,
+               moe_weight_type=moe_weight_type,
                cross_attention_freq=2
            )
        else:

@@ -118,7 +119,6 @@ class Blip2VicunaInstruct(Blip2Base):
                num_query_token, self.visual_encoder.num_features
            )

-       # import pdb;pdb.set_trace()
        if not qformer_text_input:
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
@@ -178,6 +178,19 @@ class Blip2VicunaInstruct(Blip2Base):
                if "_query" in name and "experts" not in name:  # raw ffn_query not update
                    param.requires_grad = False

+               ln_pattern = r"bert\.encoder\.layer\.\d+\.expert_ln\.(weight|bias)"
+               if re.match(ln_pattern, name):
+                   key_orig = re.sub('expert_ln', 'output_query.LayerNorm', name)
+                   param.data.copy_(state_dict[key_orig])
+               d1_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense1\.(weight|bias)"
+               if re.match(d1_pattern, name):
+                   key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense1', 'intermediate_query.dense', name)
+                   param.data.copy_(state_dict[key_orig])
+               d2_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense2\.(weight|bias)"
+               if re.match(d2_pattern, name):
+                   key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense2', 'output_query.dense', name)
+                   param.data.copy_(state_dict[key_orig])
+
        # freeze qformer
        if freeze_qformer:
            for name, param in self.Qformer.named_parameters():
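The new regex block initializes the freshly added expert modules from the pretrained Q-Former weights: expert_ln is copied from output_query.LayerNorm, dense1 from intermediate_query.dense, and dense2 from output_query.dense. A self-contained sketch of the same copy-by-renaming idea, using hypothetical parameter names rather than the real checkpoint:

import re
import torch
import torch.nn as nn

# Toy stand-ins: `state_dict` plays the pretrained checkpoint, `model_params`
# the new model's named parameters.
state_dict = {"layer.0.output_query.LayerNorm.weight": torch.ones(4)}
model_params = {"layer.0.expert_ln.weight": nn.Parameter(torch.zeros(4))}

# Map each new expert parameter name back to the pretrained parameter it
# should start from, then copy the tensor in place.
for name, param in model_params.items():
    if re.match(r"layer\.\d+\.expert_ln\.(weight|bias)", name):
        key_orig = re.sub("expert_ln", "output_query.LayerNorm", name)
        param.data.copy_(state_dict[key_orig])

print(model_params["layer.0.expert_ln.weight"])  # now all ones, copied from the "checkpoint"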
@@ -205,6 +218,7 @@ class Blip2VicunaInstruct(Blip2Base):
        self.use_moeqformer = use_moeqformer
        self.use_route_moe = use_route_moe
        self.moebert_load_balance = moebert_load_balance
+       self.moebert_num_beams = moebert_num_beams

        self.gate_save_path = gate_save_path
        # if self.gate_save_path != None:

@@ -242,7 +256,7 @@ class Blip2VicunaInstruct(Blip2Base):
        # print(samples["text_input"])
        # print(samples["text_output"])
        # print('-----------------')
-       # import pdb;pdb.set_trace()
+       # import pdb;pdb.set_trace()  # 0107test
        image = samples["image"]
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
@@ -278,10 +292,10 @@ class Blip2VicunaInstruct(Blip2Base):
            return_dict=True,
            output_hidden_states=True,
        )

+       # import pdb; pdb.set_trace()  # 0107test
        query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:]

-       if self.use_moeqformer and not self.use_route_moe:
+       if self.use_moeqformer:
            gate_loss = query_output.gate_loss  # only available in QformerMoE

        if self.gate_save_path != None:
@@ -312,7 +326,7 @@ class Blip2VicunaInstruct(Blip2Base):
                        # 'gate_route_1': prob_gate_normalized[0][i].tolist(),
                    })
                # for layer in [6,8,10]:
-               #     layer_data = all_hidden_states[layer]
+               #     layer_data = all_hidden_states[layer]s
                #     file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
                #     x = layer_data.data.cpu().numpy()
                #     np.save(file_path,x)

@@ -323,7 +337,6 @@ class Blip2VicunaInstruct(Blip2Base):
                print("Gate Save Error....")
                print(e)

-
        inputs_llm = self.llm_proj(query_output_to_linear)
        atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
@@ -380,7 +393,7 @@ class Blip2VicunaInstruct(Blip2Base):
            labels=targets,
        )

-       if self.use_moeqformer and not self.use_route_moe:
+       if self.use_moeqformer:
            loss = outputs.loss + self.moebert_load_balance * gate_loss
        else:
            loss = outputs.loss
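With the `not self.use_route_moe` condition dropped, the gate/importance loss is now added for every MoE Q-Former variant, scaled by moebert_load_balance. A toy sketch of the final objective; all values below are placeholders, not measured numbers:

import torch

llm_loss = torch.tensor(2.31)   # language-model loss (placeholder)
gate_loss = torch.tensor(0.35)  # summed over MoE layers by the encoder
moebert_load_balance = 0.05     # assumed config value
loss = llm_loss + moebert_load_balance * gate_loss
print(loss)  # tensor(2.3275)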
@@ -441,6 +454,8 @@ class Blip2VicunaInstruct(Blip2Base):
            output_hidden_states=True,
        )

+       # import pdb; pdb.set_trace()
+
        if self.gate_save_path != None:
            all_hidden_states = query_output.hidden_states
            # prob_gate_normalized = query_output.gate_loads
@@ -471,11 +486,11 @@ class Blip2VicunaInstruct(Blip2Base):
                        # 'gate_route_3': prob_gate_normalized[2][i].tolist(),
                        # 'gate_route_1': prob_gate_normalized[0][i].tolist(),
                    })
-                   for layer in [6,8,10]:
-                       if layer == 6:
-                           layer_data = all_hidden_states[layer][i, :32, :]
+                   for layer in [6,7,8,9,10,11]:
+                       if layer in [6,11]:
+                           layer_data = all_hidden_states[layer][i, :, :]
                        else:
-                           layer_data = all_hidden_states[layer][i*3, :32, :]
+                           layer_data = all_hidden_states[layer][i*self.moebert_num_beams, :, :]
                        file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
                        x = layer_data.data.cpu().numpy()
                        np.save(file_path,x)  # all done

@@ -683,5 +698,6 @@ class Blip2VicunaInstruct(Blip2Base):
        for name, param in model.named_parameters():
            if param.requires_grad == True:
                print(name)

        # [name for name, param in model.named_parameters() if (param.requires_grad == False and 'Qformer' in name and 'intermediate_query' in name)]
+       # import pdb; pdb.set_trace()  # 0107test
        return model
@@ -21,7 +21,6 @@ class MoELayer(nn.Module):
        else:
            raise KeyError("Routing method not supported.")

-
    def _forward_gate_sentence(self, x, attention_mask):
        """
        x: query_attention_output , torch.Size([bz, 32, 768])
@@ -77,6 +76,64 @@ class MoELayer(nn.Module):
        print('Layer Qformer MoE: \n',prob_gate)
        return moe_result, select_prob_gate, gate

+   def _forward_gate_sentence_post(self, x, attention_mask):
+       """
+       x: query_attention_output; torch.Size([bz, 32, 768])
+       attention_mask: torch.ones([bz, 32])
+       bz = 4
+       x = torch.randn(bz,32,768)
+       attention_mask = torch.ones([bz, 32])
+       """
+       attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
+       x_masked = x * attention_mask.unsqueeze(-1)  # torch.Size([bz, 32, 768])
+
+       def forward_expert(input_x, expert_idx):
+           # input_x += torch.randn(4,32,768)
+           # return input_x
+           output_x = self.experts[expert_idx].forward(input_x)
+           return output_x
+
+       outputs = list()
+       logits_gate_lst = list()
+       for expert_idx in range(self.num_experts):
+           output_x = forward_expert(x_masked, expert_idx)
+           outputs.append(output_x.unsqueeze(0))
+           output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz, 768])
+           # gate_score = self.gates[expert_idx](output_x_aver)
+           gate_score = self.gate(output_x_aver)
+           logits_gate_lst.append(gate_score)
+
+       candidate_output = torch.cat(outputs)  # torch.Size([num_expert, bz, 32, 768])
+       logits_gate = torch.cat(logits_gate_lst,dim=1)  # torch.Size([bz, num_expert])
+       prob_gate = F.softmax(logits_gate, dim=-1)  # torch.Size([bz, num_experts])
+       topk_values, gate = torch.topk(prob_gate, self.topk, dim=1)  # gate: the expert(s) assigned to each sample, torch.Size([bz, topk])
+       num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0)  # number of samples assigned to each expert, torch.Size([num_expert])
+       gate_load = num_sentences.clone()
+
+       # load balancing loss
+       if self.use_balance_loss:
+           balance_loss = self._balancing_loss(prob_gate, num_sentences)
+       else:
+           balance_loss = 0.0
+
+       # importance loss
+       importance_loss = self._importance_auxiliary_loss(prob_gate)
+
+       # output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2)  # torch.Size([num_expert, bz, 768])
+       # output_average = torch.permute(output_average, (1, 0, 2))  # torch.Size([bz, num_expert, 768])
+       # logits_gate = self.gate(output_average)  # torch.Size([bz, num_experts, 1])
+
+       prob_gate_topk = torch.zeros_like(prob_gate)
+       prob_gate_topk.scatter_(1, gate, topk_values)
+       prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True)  # torch.Size([bz, num_expert])
+       candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3))  # torch.Size([bz, num_expert, 32, 768])
+       results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad  # torch.Size([bz, num_expert, 32, 768])
+       moe_result = torch.sum(results, dim=1)  # torch.Size([bz, 32, 768])
+       import pdb;pdb.set_trace()
+
+       return moe_result, (balance_loss+importance_loss), prob_gate_normalized

    def forward(self, x, attention_mask):
        if self.route_method == "gate-token":
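The `_forward_gate_sentence_post` method added above routes after the experts run: every expert processes the sequence, each expert's pooled output is scored by a shared gate, and the top-k expert outputs are mixed by their renormalized probabilities. A compact runnable sketch of that flow with small illustrative dimensions (not the repository's module):

import torch
import torch.nn as nn
import torch.nn.functional as F

bz, seq, hid, num_experts, topk = 4, 32, 16, 3, 2
experts = nn.ModuleList([nn.Linear(hid, hid) for _ in range(num_experts)])
gate = nn.Linear(hid, 1, bias=False)

x = torch.randn(bz, seq, hid)
outs, logits = [], []
for e in experts:
    y = e(x)                            # every expert sees the full input
    outs.append(y.unsqueeze(1))
    logits.append(gate(y.mean(dim=1)))  # score computed from the expert's own pooled output
candidate = torch.cat(outs, dim=1)                   # [bz, num_experts, seq, hid]
prob = F.softmax(torch.cat(logits, dim=1), dim=-1)   # [bz, num_experts]

topk_vals, topk_idx = torch.topk(prob, topk, dim=1)
sparse = torch.zeros_like(prob).scatter_(1, topk_idx, topk_vals)
weights = sparse / sparse.sum(dim=1, keepdim=True)   # renormalize over the selected experts
moe_out = (weights.unsqueeze(-1).unsqueeze(-1) * candidate).sum(dim=1)
print(moe_out.shape)  # torch.Size([4, 32, 16])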
@@ -95,7 +152,7 @@ class MoELayer(nn.Module):


class RouteMoELayer(nn.Module):
-   def __init__(self, hidden_size, expert, gate, num_experts, num_beams=2, layer_judge=None, route_method="pre-route"):
+   def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
        # remove hash list
        nn.Module.__init__(self)
        self.num_experts = num_experts
@@ -103,13 +160,26 @@ class RouteMoELayer(nn.Module):
        self.num_beams = num_beams
        self.hidden_size = hidden_size
        self.layer_judge = layer_judge
+       self.weight_type = weight_type

        self.route_method = route_method
        if self.route_method == "pre-route":
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif self.route_method == "post-route":
-           # gate = nn.Linear(hidden_size, 1, bias=False).float()
-           self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
+           gate = nn.Linear(hidden_size, 1, bias=False).float()
+           self.gate = gate
+           # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])

+   def _importance_auxiliary_loss(self, prob_gate):
+       # From VMoE
+       # _importance_auxiliary_loss
+       axis = tuple(range(prob_gate.ndim - 1))  # All except last.
+       importance_per_expert = torch.sum(prob_gate, dim=axis)
+       std_importance_per_expert = torch.std(importance_per_expert)
+       mean_importance_per_expert = torch.mean(importance_per_expert)
+       # Compute coefficient of variation (i.e. std/mean) squared.
+       return (std_importance_per_expert / mean_importance_per_expert)**2

    def forward_gate(self, x):
        """
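`_importance_auxiliary_loss` follows V-MoE: sum each expert's routing probability over the batch and penalize the squared coefficient of variation (std/mean) of those totals. A worked example with tiny hand-picked inputs; note the loss is exactly zero for a perfectly balanced router:

import torch

def importance_aux_loss(prob_gate):
    # prob_gate: [batch, num_experts] routing probabilities
    importance = prob_gate.sum(dim=0)                   # total prob mass per expert
    return (importance.std() / importance.mean()) ** 2  # squared coefficient of variation

balanced = torch.full((4, 2), 0.5)          # both experts get equal mass
print(importance_aux_loss(balanced))        # tensor(0.)

collapsed = torch.tensor([[1.0, 0.0]] * 4)  # router always picks expert 0
print(importance_aux_loss(collapsed))       # tensor(2.) = (2.828/2)^2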
@@ -123,19 +193,21 @@ class RouteMoELayer(nn.Module):
        prob_gate = F.softmax(logits_gate, dim=-1)  # torch.Size([bz*num_beams, num_experts])
        return prob_gate

    def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
        import pdb;pdb.set_trace()

    def beam_search_backup(self, current_scores_log, beam_scores, expert_route, batch_size):
        if self.layer_judge=='first' and self.route_method=='pre-route':
            # current_scores_log torch.Size([bz, num_experts])
            assert beam_scores==None and expert_route==None
            current_scores = torch.exp(current_scores_log)
            topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1)  # gate: the expert assigned to each sample, torch.Size([bz, topk])
            beam_scores = topk_values.view(self.num_beams * batch_size)  # torch.Size([bz * num_beams])
            expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1)  # torch.Size([bz * num_beams, 1])
-           beam_idx = None
+           beam_idx = torch.tensor(range(self.num_beams * batch_size))

        else:
+           if self.layer_judge=='first' and self.route_method == 'post-route':
+               batch_size = batch_size
-               next_scores_raw1 = torch.exp(current_scores_log)  # torch.Size([bz, num_experts])
+               next_scores_raw1 = torch.exp(current_scores_log)  # torch.Size([bz, num_beams*num_experts])
+           else:
                batch_size = int(batch_size // self.num_beams)
                next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1)  # torch.Size([4*3, 5]); after taking the log, probabilities can be summed directly
@@ -147,9 +219,6 @@ class RouteMoELayer(nn.Module):
            next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
            # next_scores torch.Size([bz, num_beams])
            # next_tokens torch.Size([bz, num_beams])
-           print(next_scores_raw1)
-           print(next_scores)
-           print(next_experts)

            next_batch_beam = list()
            for batch_idx in range(batch_size):
@@ -181,33 +250,91 @@ class RouteMoELayer(nn.Module):
            pre_route = expert_route[beam_idx,:]
            expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)

        return beam_scores, expert_route, beam_idx

+   def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
+       if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
+           # current_scores_log torch.Size([bz, num_experts])
+           assert beam_scores==None and expert_route==None
+           current_scores = torch.exp(current_scores_log)
+           topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1)  # gate: the expert assigned to each sample, torch.Size([bz, topk])
+           beam_scores = topk_values.view(self.num_beams * batch_size)  # torch.Size([bz * num_beams])
+           expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1)  # torch.Size([bz * num_beams, 1])
+           beam_idx = torch.tensor(range(self.num_beams * batch_size))
+           import pdb;pdb.set_trace()

+       else:
+           batch_size = int(batch_size // self.num_beams)
+           next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1)  # after taking the log, probabilities can be summed directly
+           next_scores_exp = torch.exp(next_scores_raw)
+           next_scores_raw1 = next_scores_exp.view(
+               batch_size, self.num_beams * self.num_experts
+           )  # torch.Size([bz, num_beams*num_experts])

+           next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
+           # next_scores torch.Size([bz, num_beams])
+           # next_tokens torch.Size([bz, num_beams])

+           next_batch_beam = list()
+           for batch_idx in range(batch_size):
+               next_sent_beam = list()
+               for rank, (expert_id, expert_score) in enumerate(
+                   zip(next_experts[batch_idx], next_scores[batch_idx])
+               ):
+                   expert_id = expert_id.item()
+                   beam_id = expert_id // self.num_experts
+                   ex_id = expert_id % self.num_experts
+                   effective_beam_id = batch_idx*self.num_beams + beam_id

+                   next_sent_beam.append((expert_score, ex_id, effective_beam_id))
+               next_batch_beam.extend(next_sent_beam)

+           # import pdb;pdb.set_trace()

+           beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+           beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
+           beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
+           pre_route = expert_route[beam_idx,:]
+           expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)

+           print("next_scores_raw1:\n",next_scores_raw1)

+       return beam_scores, expert_route, beam_idx

-   def forward_expert_ffn(self, x, expert_select, beam_scores):
+   def forward_expert_ffn(self, x, expert_select, current_scores):
        """
        x_repeat : [bz*num_beams, 32,768]
        expert_select : [bz*num_beams]
+       current_scores : [bz*num_beams, num_experts] / [bz, num_experts]
        """
-       # add_1212 l2_normalization
-       # normalized_tensor = torch.nn.functional.normalize(beam_scores, p=2, dim=0)  # L2 Normalization torch.Size([bz, topk])
+       # add_1228 l2_normalization
+       # normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0)  # L2 Normalization torch.Size([bz, topk])
        # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)

+       import pdb;pdb.set_trace()
        outputs = list()
-       for i in range(x.shape[0]):
-           output_x = self.experts[expert_select[i]].forward(x[i])
-           outputs.append(output_x.unsqueeze(0))
-       candidate_output = torch.cat(outputs)
+       for i in range(self.num_experts):
+           output_x = self.experts[i].forward(x)
+           outputs.append(output_x.unsqueeze(1))
+       candidate_output = torch.cat(outputs, dim=1)
+       expert_select_matrix = F.one_hot(expert_select, self.num_experts)

-       # candidate_output = candidate_output * tmp_prob
-       return candidate_output  # torch.Size([bz*num_beams, 32, 768])
+       if self.weight_type == 'ffn_prob':
+           tmp_prob = current_scores * expert_select_matrix
+           candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
+       else:
+           candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
+       import pdb;pdb.set_trace()
+       output = torch.sum(candidate_output, dim=1)

+       return output  # torch.Size([bz*num_beams, 32, 768])

    def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
+       import pdb;pdb.set_trace()
-       current_scores = self.forward_gate(x)  # [bz, num_beams] / [bz*num_beams, num_beams]
+       current_scores = self.forward_gate(x)  # [bz*num_beams, 5]
+       importance_loss = self._importance_auxiliary_loss(current_scores)

        if use_log:
            current_scores_log = torch.log(current_scores)  # taking the log lets probabilities be summed directly
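The routing beam search above mirrors token-level beam search: each surviving path's running log-probability is extended by every expert's log gate score, the flattened bz x (num_beams*num_experts) candidates are top-k'd, and beam_idx records which parent beam each survivor came from. A standalone single-layer sketch with toy sizes (not the repository's exact code):

import torch
import torch.nn.functional as F

bz, num_beams, num_experts = 2, 2, 4
scores = torch.log(F.softmax(torch.randn(bz * num_beams, num_experts), dim=-1))
beam_scores = torch.zeros(bz * num_beams)                # running log-prob per beam
expert_route = torch.zeros(bz * num_beams, 0, dtype=torch.long)

total = scores + beam_scores.unsqueeze(1)                # extend every beam by every expert
flat = total.view(bz, num_beams * num_experts)           # candidates per original sample
next_scores, next_ids = torch.topk(flat, num_beams, dim=1)

beam_id = next_ids // num_experts                        # which parent beam survives
expert_id = next_ids % num_experts                       # which expert it took
beam_idx = (torch.arange(bz).unsqueeze(1) * num_beams + beam_id).view(-1)

beam_scores = next_scores.view(-1)
expert_route = torch.cat([expert_route[beam_idx], expert_id.view(-1, 1)], dim=-1)
print(expert_route)  # [bz*num_beams, 1]: the expert chosen by each surviving beam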
@@ -215,26 +342,25 @@ class RouteMoELayer(nn.Module):
        else:
            current_scores_log = current_scores

        batch_size, num_tokens = x.shape[0], x.shape[1]
-       beam_scores, expert_route, _ = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
+       beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
        current_expert_select = expert_route[:,-1]

+       import pdb;pdb.set_trace()

        if self.layer_judge=='first':  # expand first dim to batch_size * num_beams
            replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
            x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size)  # [bz*num_beams, 32,768]
+           current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
+           current_scores = current_scores_t.contiguous().view(-1, self.num_experts)  # [bz*num_beams, num_experts]

-       candidate_output = self.forward_expert_ffn(x, current_expert_select, beam_scores)  # [bz*num_beams, 32,768]
-
-       return candidate_output, beam_scores, expert_route
+       input_x = x[beam_idx]
+       candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores)  # [bz*num_beams, 32,768]
+       import pdb;pdb.set_trace()

+       return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

    def forward_post_route(self, x, beam_scores, expert_route, use_log=True):

        # if self.layer_judge=='first':  # expand first dim to batch_size * num_beams
        #     batch_size, num_tokens = x.shape[0], x.shape[1]
        #     replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
        #     x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size)  # [bz*num_beams, 32,768]

        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)  # torch.Size([bz, 32, 768])
@@ -242,15 +368,19 @@ class RouteMoELayer(nn.Module):
            output_x = self.experts[expert_idx].forward(input_x)
            return output_x

+       import pdb; pdb.set_trace()
        outputs = list()
        logits_gate_lst = list()
        for expert_idx in range(self.num_experts):
            output_x = forward_expert(x_masked, expert_idx)
+           # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz*num_beam, 768])
+           output_x_aver = torch.mean(output_x, dim=1)
+           # gate_score = self.gates[expert_idx](output_x_aver)
+           gate_score = self.gate(output_x_aver)
+           logits_gate_lst.append(gate_score)
            outputs.append(output_x.unsqueeze(0))
-           output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz*num_beam, 768])
-           gate_acore = self.gates[expert_idx](output_x_aver)
-           logits_gate_lst.append(gate_acore)
-       candidate_output = torch.cat(outputs)  # torch.Size([num_expert, bz*num_beam, 32, 768])

+       candidate_output_raw = torch.cat(outputs)  # torch.Size([num_expert, bz*num_beam, 32, 768])
        logits_gate = torch.cat(logits_gate_lst,dim=1)  # torch.Size([bz*num_beam, num_expert])
        current_scores = F.softmax(logits_gate, dim=-1)  # torch.Size([bz*num_beam, num_experts])
@@ -259,25 +389,39 @@ class RouteMoELayer(nn.Module):
        else:
            current_scores_log = current_scores

-       import pdb;pdb.set_trace()
        # importance loss
        importance_loss = self._importance_auxiliary_loss(current_scores)

-       batch_size = x.shape[0]  # bz*num_beam
+       # import pdb; pdb.set_trace()
+       batch_size, num_tokens = x.shape[0], x.shape[1]  # bz*num_beam
        beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
        # beam_scores torch.Size([bz*num_beam])
        # expert_route torch.Size([bz*num_beam, layer_n])
        current_select_expert = expert_route[:,-1]
        # current_select_expert torch.Size([bz*num_beam, 1])

-       output = list()
-       for i in range(beam_idx.shape[0]):
-           b_idx = beam_idx[i]
-           ex_idx = current_select_expert[i]
-           ex_out = candidate_output[ex_idx, b_idx, :,:]
-           output.append(ex_out.unsqueeze(0))
-       # import pdb; pdb.set_trace()
-
-       final_output = torch.concat(output, dim=0)
+       if self.layer_judge == 'first':
+           replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
+           candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size)  # [num_experts, bz*num_beams, 32, 768]
+           current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
+           current_scores = current_scores_t.contiguous().view(-1, self.num_experts)  # [bz*num_beams, num_experts]

-       return final_output, beam_scores, expert_route, beam_idx
+       candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx]  # torch.Size([8, 2, 32, 768])
+       expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
+       if self.weight_type == 'ffn_prob':
+           tmp_prob = current_scores[beam_idx] * expert_select_matrix
+           output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
+       else:
+           output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
+       final_output = torch.sum(output, dim=1)

+       import pdb; pdb.set_trace()
+       print("current_scores:\n",current_scores)

+       return final_output, beam_scores, expert_route, beam_idx, importance_loss
    def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
        """

@@ -286,12 +430,11 @@ class RouteMoELayer(nn.Module):

        """
        if self.route_method == 'pre-route':
-           candidate_output, beam_scores, expert_route, _ = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
+           candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
        elif self.route_method == "post-route":
-           candidate_output, beam_scores, expert_route, beam_idx = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
-
-       return candidate_output, beam_scores, expert_route, beam_idx
+           candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)

+       return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

if __name__ == '__main__':
@@ -314,8 +457,8 @@ if __name__ == '__main__':
    config.add_cross_attention = True
    config.cross_attention_freq = cross_attention_freq
    config.query_length = num_query_token
-   config.moebert_expert_num = 3
-   config.moebert_num_beams = 3
+   config.moebert_expert_num = 2
+   config.moebert_num_beams = 2
    config.moebert_route_method = 'gate-sentence'
    config.moe_topk = 2
    config.use_balance_loss = False
@@ -332,40 +475,46 @@ if __name__ == '__main__':
    for layer_num in [6, 8, 10]:
        layer_judge = moe_layer_judge(layer_num)
        ffn = FeedForward(config)
        gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()

        # experts = RouteMoELayer(
        #     hidden_size=768,
        #     expert=ffn,
        #     gate = gate,
        #     num_experts=config.moebert_expert_num,
        #     num_beams=config.moebert_num_beams,
        #     layer_judge = layer_judge,
-       #     route_method = "pre-route"
+       #     route_method = "pre-route",
+       #     weight_type="no_ffn_prob"
        # )
        # layer_output = experts(x, None, beam_scores, expert_route)
-       # hidden_states1, beam_scores, expert_route,_ = layer_output
+       # hidden_states1, beam_scores, expert_route, beam_idx, importance_loss = layer_output

        # print(beam_scores)
        # print(expert_route)
+       # print(beam_idx)
+       # print(importance_loss)
        # x = hidden_states1

        gate1 = nn.Linear(768, 1, bias=False).float()
        experts_post = RouteMoELayer(
            hidden_size=768,
            expert=ffn,
-           gate = gate1,
            num_experts=config.moebert_expert_num,
            num_beams=config.moebert_num_beams,
            layer_judge = layer_judge,
-           route_method = "post-route"
+           route_method = "post-route",
+           weight_type="ffn_prob"
        )
        layer_output = experts_post(x1, None, beam_scores1, expert_route1, False)
-       hidden_states2, beam_scores1, expert_route1, beam_idx = layer_output
+       hidden_states2, beam_scores1, expert_route1, beam_idx, importance_loss = layer_output

        print(beam_scores1)
        print(expert_route1)
        print(beam_idx)
+       print(importance_loss)
        x1 = hidden_states2


        # gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
        # experts_moe = MoELayer(
        #     hidden_size=config.hidden_size,
        #     expert=ffn,
@@ -382,11 +531,62 @@ if __name__ == '__main__':

    # print(select_prob_gate)
    # print(gate_load)

    # x = hidden_states1
    x1 = hidden_states2
    # x2 = hidden_states3

    print("------------------------------------")
    import pdb; pdb.set_trace()


+def forward_post_route_backup(self, x, beam_scores, expert_route, use_log=True):
+
+    attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
+    x_masked = x * attention_mask.unsqueeze(-1)  # torch.Size([bz, 32, 768])
+
+    def forward_expert(input_x, expert_idx):
+        output_x = self.experts[expert_idx].forward(input_x)
+        return output_x
+
+    outputs = list()
+    logits_gate_lst = list()
+    for expert_idx in range(self.num_experts):
+        output_x = forward_expert(x_masked, expert_idx)
+        outputs.append(output_x.unsqueeze(0))
+        # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz*num_beam, 768])
+        # gate_score = self.gates[expert_idx](output_x_aver)
+        output_x_aver = torch.mean(output_x, dim=1)
+        gate_score = self.gate(output_x_aver)
+        logits_gate_lst.append(gate_score)
+    candidate_output = torch.cat(outputs)  # torch.Size([num_expert, bz*num_beam, 32, 768])
+    logits_gate = torch.cat(logits_gate_lst,dim=1)  # torch.Size([bz*num_beam, num_expert])
+    current_scores = F.softmax(logits_gate, dim=-1)  # torch.Size([bz*num_beam, num_experts])
+
+    if use_log:
+        current_scores_log = torch.log(current_scores)  # taking the log lets probabilities be summed directly
+    else:
+        current_scores_log = current_scores
+
+    # importance loss
+    importance_loss = self._importance_auxiliary_loss(current_scores)
+
+    batch_size = x.shape[0]  # bz*num_beam
+    beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
+    # beam_scores torch.Size([bz*num_beam])
+    # expert_route torch.Size([bz*num_beam, layer_n])
+    current_select_expert = expert_route[:,-1]
+    # current_select_expert torch.Size([bz*num_beam, 1])
+
+    output = list()
+    for i in range(beam_idx.shape[0]):
+        b_idx = beam_idx[i]
+        ex_idx = current_select_expert[i]
+        ex_out = candidate_output[ex_idx, b_idx, :,:]
+        if self.weight_type == 'ffn_prob':
+            prob = current_scores[b_idx, ex_idx]
+            ex_out = ex_out*prob
+        output.append(ex_out.unsqueeze(0))
+
+    final_output = torch.concat(output, dim=0)
+    # import pdb;pdb.set_trace()
+    return final_output, beam_scores, expert_route, beam_idx, importance_loss
@@ -1,155 +0,0 @@
import torch
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")


def forward_expert(input_x, expert_idx):
    input_x += torch.randn(32,768)
    return input_x
    # output_x = self.experts[expert_idx].forward(input_x)
    # return output_x


def forward_ffn(x_repeat, expert_select):
    """
    x_repeat : [bz*num_beams, 32,768]
    expert_select : [bz*num_beams]
    """
    outputs = list()
    num_beams_bz = x_repeat.shape[0]
    for i in range(num_beams_bz):
        output_x = forward_expert(x_repeat[i], expert_select[i])  # (32,768)
        outputs.append(output_x.unsqueeze(0))
    candidate_output = torch.cat(outputs)
    return candidate_output  # torch.Size([bz*num_beams, 32, 768])

def forward_gate(x, num_expert):
    """
    x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
    prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts])
    """
    # attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
    # x_masked = x * attention_mask.unsqueeze(-1)  # torch.Size([bz*num_beams, 32, 768])
    # x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz*num_beams, 768])
    # logits_gate = gate(x_average)  # torch.Size([bz, num_experts])
    logits_gate = torch.randn(x.shape[0], num_expert)
    prob_gate = F.softmax(logits_gate, dim=-1)  # torch.Size([bz*num_beams, num_experts])
    return prob_gate

def beam_search(layer, current_scores, beam_scores, expert_route, num_beams):
    if layer == 0 and beam_scores==None and expert_route==None:
        topk_values, gate = torch.topk(current_scores, num_beams, dim=1)  # gate: the expert assigned to each sample, torch.Size([bz, topk])
        beam_scores = topk_values.view(num_beams*batch_size)  # torch.Size([bz * num_beams])
        expert_route = gate.view(num_beams*batch_size).unsqueeze(1)  # torch.Size([bz * num_beams])

    else:
        next_scores_raw = current_scores + beam_scores.unsqueeze(1)  # torch.Size([4*3, 5]); after taking the log, probabilities can be summed directly
        next_scores_raw1 = next_scores_raw.view(
            batch_size, num_beams * num_expert
        )  # torch.Size([4, 3*5])
        next_scores, next_experts = torch.topk(next_scores_raw1, num_beams, dim=1, largest=True, sorted=True)
        # next_scores torch.Size([4, 3*num_beams])
        # next_tokens torch.Size([4, 3*num_beams])

        next_batch_beam = list()
        for batch_idx in range(batch_size):
            next_sent_beam = list()
            print(batch_idx)
            for rank, (expert_id, expert_score) in enumerate(
                zip(next_experts[batch_idx], next_scores[batch_idx])
            ):
                expert_id = expert_id.item()
                beam_id = expert_id // num_expert
                ex_id = expert_id % num_expert
                effective_beam_id = batch_idx*num_beams + beam_id

                # print(expert_id, beam_id, ex_id, effective_beam_id, expert_score)

                next_sent_beam.append((expert_score, ex_id, effective_beam_id))
            next_batch_beam.extend(next_sent_beam)

        # print()

        import pdb;pdb.set_trace()

        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
        beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])

        pre_route = expert_route[beam_idx,:]
        expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)

    return beam_scores, expert_route


if __name__ == '__main__':

    batch_size = 3
    num_beams = 2
    num_expert = 5
    x = torch.randn(batch_size, 32, 768)
    beam_scores, expert_route = None, None

    for layer in range(0,3):
        # import pdb;pdb.set_trace()

        current_scores = forward_gate(x, num_expert)
        import pdb;pdb.set_trace()

        beam_scores, expert_route = beam_search(layer, current_scores, beam_scores, expert_route, num_beams)
        current_expert_select = expert_route[:,-1]

        if layer == 0:
            replicated_tensor = x.unsqueeze(1).expand(batch_size, num_beams, 32, 768)
            x = replicated_tensor.contiguous().view(-1, 32, 768)  # [12,32,768] [bz*num_beams, 32,768]
        else:
            x = candidate_output

        candidate_output = forward_ffn(x, current_expert_select)  # torch.Size([4*3, 5])

        x = candidate_output


    scores = beam_scores.view(batch_size, num_beams)
    topk_values, gate = torch.topk(scores, 1, dim=1)
    # gate [batch_size, 1]
    # topk_values [batch_size, 1]
    selects = [ (bz_idx * num_beams + gate[bz_idx].item()) for bz_idx in range(batch_size)]
    final_scores = beam_scores[selects]
    final_expert_route = expert_route[selects]
    final_output = candidate_output[selects]


    # def forward_ffn_post(x_repeat, expert_select):
    #     """
    #     x_repeat : [bz*num_beams, 32,768]
    #     expert_select : [bz*num_beams]
    #     prob_gate : torch.Size([bz*num_beams, num_experts])
    #     """
    #     outputs = list()
    #     logits_gate_lst = list()
    #     # attention_mask = torch.ones([batch_size, 32])
    #     for i in range(num_beams*batch_size):
    #         output_x = forward_expert(x_repeat[i], expert_select[i])  # (32,768)
    #         outputs.append(output_x.unsqueeze(0))
    #         # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz, 768])
    #         # gate_acore = self.gates[expert_idx](output_x_aver)
    #         # gate_score = self.gate(output_x_aver)
    #         num_expert = 5
    #         gate_score = torch.randn(1,num_expert)
    #         logits_gate_lst.append(gate_score)
    #     candidate_output = torch.cat(outputs)  # torch.Size([bz*num_beams, 32, 768])
    #     logits_gate = torch.cat(logits_gate_lst,dim=0)  # torch.Size([bz*num_beams, num_expert])
    #     prob_gate = F.softmax(logits_gate, dim=-1)  # torch.Size([bz*num_beams, num_experts])
    #     return prob_gate, candidate_output
@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
-   def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'):
+   def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob'):
        # remove hash list
        nn.Module.__init__(self)
        self.num_experts = num_experts
@ -81,54 +81,6 @@ class MoELayer(nn.Module):
|
||||
|
||||
return x, balance_loss, gate_load

    def _forward_gate_sentence_top1_raw(self, x, attention_mask):
        """
        x: query_attention_output, torch.Size([bz, 32, 768])
        attention_mask: torch.ones([bz, 32])

        ### Notice:
        The raw version of expert_attention_mask is the extended_attention_mask,
        which is added to attention_score directly;
        the values of extended_attention_mask are -0.0 or -10000,
        so it should be adjusted to a 1/0 version before being processed by the experts.
        """
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)                       # torch.Size([bz, 32, 768])
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
        logits_gate = self.gate(x_average)                                # torch.Size([bz, num_experts])
        prob_gate = F.softmax(logits_gate, dim=-1)                        # torch.Size([bz, num_experts])
        gate = torch.argmax(prob_gate, dim=-1)                            # torch.Size([bz])

        order = gate.argsort(0)
        num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_sentences.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_sentences.tolist(), dim=0)  # a list of length self.num_experts

        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_sentences.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)

        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)

        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x.unsqueeze(-1)
            return input_x

        result = []
        for i in range(self.num_experts):
            if x[i].size(0) > 0:
                result.append(forward_expert(x[i], prob_gate[i], i))
        result = torch.vstack(result)
        result = result[order.argsort(0)]  # restore original order

        return result, balance_loss, gate_load

    def _forward_gate_sentence_post(self, x, attention_mask):
        """
        x: query_attention_output; torch.Size([bz, 32, 768])
@ -174,13 +126,17 @@ class MoELayer(nn.Module):
        # importance loss
        importance_loss = self._importance_auxiliary_loss(prob_gate)

        # output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
        # output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
        # logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])

        prob_gate_topk = torch.zeros_like(prob_gate)
        prob_gate_topk.scatter_(1, gate, topk_values)
        prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) # torch.Size([bz, num_expert])

        if self.weight_type == 'average':
            # torch.Size([bz, num_expert]); prob_gate_normalized is 0 for the unselected experts
            prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True)
        elif self.weight_type == 'raw_prob':
            prob_gate_normalized = prob_gate_topk
        elif self.weight_type == 'softmax_norm':
            prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert])

        candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3))              # torch.Size([bz, num_expert, 32, 768])
        results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
        moe_result = torch.sum(results, dim=1)                                           # torch.Size([bz, 32, 768])
@ -188,6 +144,46 @@ class MoELayer(nn.Module):

        return moe_result, (balance_loss+importance_loss), prob_gate_normalized
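
A minimal standalone sketch of the three weight_type options above (toy values; prob_gate and the top-k split are invented for illustration):

import torch
import torch.nn.functional as F

prob_gate = torch.tensor([[0.5, 0.3, 0.2],
                          [0.1, 0.6, 0.3]])   # [bz=2, num_expert=3] routing probabilities
topk_values, gate = torch.topk(prob_gate, k=2, dim=1)

# scatter the kept probabilities back to a dense [bz, num_expert] tensor; unselected experts stay 0
prob_gate_topk = torch.zeros_like(prob_gate)
prob_gate_topk.scatter_(1, gate, topk_values)

print(prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True))  # 'average': renormalized, e.g. [0.625, 0.375, 0.000]
print(prob_gate_topk)                                            # 'raw_prob': kept as-is, e.g. [0.5, 0.3, 0.0]
print(F.softmax(prob_gate_topk, dim=-1))                         # 'softmax_norm': note the zero entries also receive weight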

    def router(self, x, attention_mask):
        # Prepare input x
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
        x_average = torch.mean(x_masked, dim=1)     # torch.Size([bz, 768])

        # Forward Gate
        # logits_gate: [bz, num_experts]
        logits_gate = self.gate(x_average)

        # Probabilities for each sample of what expert it should be sent to.
        # prob_gate: [bz, num_experts]
        prob_gate = F.softmax(logits_gate, dim=-1)

        # Get Top-K experts for each sample
        # gate: [bz, topk]
        # select_prob_gate: [bz, topk]
        select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)

        # Reshape Prob_gate & Gate
        # expert_mask: [batch_size, topk, num_experts]
        # expert_gate: [batch_size, topk, num_experts]
        # combine_tensor: [batch_size, num_experts]
        expert_mask = F.one_hot(gate, self.num_experts)
        expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
        combine_tensor = torch.sum(expert_gate, dim=1)

        # Calculate Balancing Loss
        if self.use_balance_loss:
            num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
            balance_loss = self._balancing_loss(prob_gate, num_sentences)
        else:
            balance_loss = 0.0

        # Calculate Importance Loss
        importance_loss = self._importance_auxiliary_loss(prob_gate)

        # import pdb; pdb.set_trace()

        return expert_mask, combine_tensor, balance_loss, importance_loss
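
To make the bookkeeping in router concrete, here is a small self-contained sketch with invented values (not part of the diff):

import torch
import torch.nn.functional as F

num_experts = 4
gate = torch.tensor([[1, 3], [0, 1]])                     # [bz=2, topk=2] selected expert ids
select_prob_gate = torch.tensor([[0.6, 0.2], [0.5, 0.4]]) # their routing probabilities

expert_mask = F.one_hot(gate, num_experts)        # [bz, topk, num_experts], 0/1
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = torch.sum(expert_gate, dim=1)    # [bz, num_experts], dense per-expert weights
print(combine_tensor[0])                          # tensor([0.0000, 0.6000, 0.0000, 0.2000])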

    def _forward_gate_sentence(self, x, attention_mask):
        """
@ -200,81 +196,37 @@ class MoELayer(nn.Module):
        the values of extended_attention_mask are -0.0 or -10000,
        so it should be adjusted to a 1/0 version before being processed by the experts.
        """
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)                       # torch.Size([bz, 32, 768])
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
        logits_gate = self.gate(x_average)                                # torch.Size([bz, num_experts])
        prob_gate = F.softmax(logits_gate, dim=-1)                        # torch.Size([bz, num_experts])
        select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)  # gate: expert assigned to each sample, torch.Size([bz, topk])

        # weight the selected experts by l2 norm here
        if self.weight_type == 'l2_norm':
            normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 normalization, torch.Size([bz, topk])
        elif self.weight_type == 'average':
            normalized_tensor = select_prob_gate / select_prob_gate.sum(dim=1, keepdim=True)

        num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
        gate_load = num_sentences.clone()

        # load balancing loss
        if self.use_balance_loss:
            balance_loss = self._balancing_loss(prob_gate, num_sentences)
        else:
            balance_loss = 0.0

        # importance loss
        importance_loss = self._importance_auxiliary_loss(prob_gate)

        # forward experts
        def forward_expert(input_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            return input_x

        result_lst = list()
        for i in range(self.topk):
            # process top1, top2, ... as separate groups: split the batch by gate, run each group through its expert, then scale by probability and sum
            tmp_gate = gate[:,i]
            tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1)
            order = tmp_gate.argsort(0)
            num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0)
            x1 = x[order]  # reorder according to expert number
            x1 = x1.split(num_sentences_t.tolist(), dim=0)  # a list of length self.num_experts
        # Forward Router
        expert_mask, combine_tensor, balance_loss, importance_loss = self.router(x, attention_mask)

        # Forward Expert FFN
        result = []
        for i in range(self.num_experts):
            if x1[i].size(0) > 0:
                result.append(forward_expert(x1[i], i))
            result = torch.vstack(result)
            result = result[order.argsort(0)]  # restore original order
            # result_lst.append(result * tmp_prob) # result * prob
            result_lst.append(result) # result * prob # add_1212
        for expert_idx in range(self.num_experts):
            output_x = self.experts[expert_idx].forward(x)
            result.append(output_x.unsqueeze(0))
        expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])

        # multiply outputs of experts by the routing probability
        if self.weight_type == 'raw_prob':
            expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
        elif self.weight_type == 'no_prob':
            combine_index = torch.sum(expert_mask, dim=1)
            expert_outputs_combined = expert_output * combine_index.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])

        outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states])

        moe_result = sum(result_lst)
        # import pdb; pdb.set_trace()
        return moe_result, (balance_loss+importance_loss), gate

    def _forward_sentence_single_expert(self, x, attention_mask):
        x_masked = x * attention_mask.unsqueeze(-1)
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
        logits_gate = self.gate(x_average)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        x = self.experts[gate.cpu().item()].forward(x)
        return x, 0.0, gate_load
        return outputs, (balance_loss+importance_loss), combine_tensor

    def forward(self, x, attention_mask):
        if self.route_method == "gate-token":
            x, balance_loss, gate_load = self._forward_gate_token(x)
        elif self.route_method == "gate-sentence":
            if x.size(0) == 1:
                x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
            else:
                x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
        elif self.route_method == "gate-sentence-post":
            x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
        else:
            raise KeyError("Routing method not supported.")

        # import pdb; pdb.set_trace()
        return x, balance_loss, gate_load

330
minigpt4/models/moe/moe_layer_backup.py
Normal file
@ -0,0 +1,330 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'):
        # remove hash list
        nn.Module.__init__(self)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
        self.route_method = route_method
        self.topk = topk
        self.use_balance_loss = use_balance_loss
        self.weight_type = weight_type

        if route_method in ["gate-token", "gate-sentence"]:
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif route_method in ["gate-sentence-post"]:
            gate = nn.Linear(hidden_size, 1, bias=False).float()
            # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
            self.gate = gate
        else:
            raise KeyError("Routing method not supported.")
    def _balancing_loss(self, prob_gate, num_tokens):
        # From MOEBERT
        # compute the load balancing loss
        # prob_gate is [bz, num_expert]: the probability of each sample being assigned to each expert
        # equivalent to _gshard_auxiliary_loss in VMOE
        P = prob_gate.mean(0) # torch.Size([num_expert]); average probability assigned to each expert
        temp = num_tokens.float()
        f = temp / temp.sum(0, keepdim=True) # fraction of samples assigned to each expert
        balance_loss = self.num_experts * torch.sum(P * f)
        return balance_loss
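
A quick worked example of this loss (invented numbers): with perfectly uniform routing, P = f = 1/num_experts and the loss equals 1; skewed routing pushes it above 1.

import torch

num_experts = 2
prob_gate = torch.tensor([[0.9, 0.1],
                          [0.8, 0.2]])     # both samples favor expert 0
num_tokens = torch.tensor([2, 0])          # argmax sends both samples to expert 0

P = prob_gate.mean(0)                      # tensor([0.85, 0.15])
f = num_tokens.float() / num_tokens.sum()  # tensor([1., 0.])
print(num_experts * torch.sum(P * f))      # 1.70 > 1.0, penalizing the imbalance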

    def _importance_auxiliary_loss(self, prob_gate):
        # From VMOE
        # _importance_auxiliary_loss
        axis = tuple(range(prob_gate.ndim - 1))  # All except last.
        importance_per_expert = torch.sum(prob_gate, dim=axis)
        std_importance_per_expert = torch.std(importance_per_expert)
        mean_importance_per_expert = torch.mean(importance_per_expert)
        # Compute coefficient of variation (i.e. std/mean) squared.
        return (std_importance_per_expert / mean_importance_per_expert)**2
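
For intuition, a tiny numeric check of the coefficient-of-variation penalty (values invented): equal per-expert importance gives 0, unequal importance gives a positive loss.

import torch

prob_gate = torch.tensor([[0.9, 0.1],
                          [0.7, 0.3]])  # [bz, num_experts]
importance = prob_gate.sum(dim=0)       # tensor([1.6000, 0.4000])
print((torch.std(importance) / torch.mean(importance)) ** 2)  # ~0.72; would be 0.0 for equal importance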

    def _forward_gate_token(self, x):
        bsz, seq_len, dim = x.size()

        x = x.view(-1, dim)
        logits_gate = self.gate(x)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        order = gate.argsort(0)
        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_tokens.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts

        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_tokens.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)

        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)

        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x
            return input_x

        x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
        x = torch.vstack(x)
        x = x[order.argsort(0)]  # restore original order
        x = x.view(bsz, seq_len, dim)

        return x, balance_loss, gate_load

    def _forward_gate_sentence_top1_raw(self, x, attention_mask):
        """
        x: query_attention_output, torch.Size([bz, 32, 768])
        attention_mask: torch.ones([bz, 32])

        ### Notice:
        The raw version of expert_attention_mask is the extended_attention_mask,
        which is added to attention_score directly;
        the values of extended_attention_mask are -0.0 or -10000,
        so it should be adjusted to a 1/0 version before being processed by the experts.
        """
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)                       # torch.Size([bz, 32, 768])
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
        logits_gate = self.gate(x_average)                                # torch.Size([bz, num_experts])
        prob_gate = F.softmax(logits_gate, dim=-1)                        # torch.Size([bz, num_experts])
        gate = torch.argmax(prob_gate, dim=-1)                            # torch.Size([bz])

        order = gate.argsort(0)
        num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_sentences.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_sentences.tolist(), dim=0)  # a list of length self.num_experts

        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_sentences.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)

        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)

        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x.unsqueeze(-1)
            return input_x

        result = []
        for i in range(self.num_experts):
            if x[i].size(0) > 0:
                result.append(forward_expert(x[i], prob_gate[i], i))
        result = torch.vstack(result)
        result = result[order.argsort(0)]  # restore original order

        return result, balance_loss, gate_load

    def _forward_gate_sentence_post(self, x, attention_mask):
        """
        x: query_attention_output; torch.Size([bz, 32, 768])
        attention_mask: torch.ones([bz, 32])
        bz = 4
        x = torch.randn(bz,32,768)
        attention_mask = torch.ones([bz, 32])

        """
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])

        def forward_expert(input_x, expert_idx):
            # input_x += torch.randn(4,32,768)
            # return input_x
            output_x = self.experts[expert_idx].forward(input_x)
            return output_x

        outputs = list()
        logits_gate_lst = list()
        for expert_idx in range(self.num_experts):
            output_x = forward_expert(x_masked, expert_idx)
            outputs.append(output_x.unsqueeze(0))

            output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
            # gate_score = self.gates[expert_idx](output_x_aver)
            gate_score = self.gate(output_x_aver)
            logits_gate_lst.append(gate_score)

        candidate_output = torch.cat(outputs)           # torch.Size([num_expert, bz, 32, 768])
        logits_gate = torch.cat(logits_gate_lst, dim=1) # torch.Size([bz, num_expert])
        prob_gate = F.softmax(logits_gate, dim=-1)      # torch.Size([bz, num_experts])
        topk_values, gate = torch.topk(prob_gate, self.topk, dim=1)          # gate: expert assigned to each sample, torch.Size([bz, topk])
        num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
        gate_load = num_sentences.clone()

        # load balancing loss
        if self.use_balance_loss:
            balance_loss = self._balancing_loss(prob_gate, num_sentences)
        else:
            balance_loss = 0.0

        # importance loss
        importance_loss = self._importance_auxiliary_loss(prob_gate)

        # output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
        # output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
        # logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])

        prob_gate_topk = torch.zeros_like(prob_gate)
        prob_gate_topk.scatter_(1, gate, topk_values)

        if self.weight_type == 'average':
            # torch.Size([bz, num_expert]); prob_gate_normalized is 0 for the unselected experts
            prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True)
        elif self.weight_type == 'raw_prob':
            prob_gate_normalized = prob_gate_topk
        elif self.weight_type == 'softmax_norm':
            prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert])

        candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3))              # torch.Size([bz, num_expert, 32, 768])
        results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
        moe_result = torch.sum(results, dim=1)                                           # torch.Size([bz, 32, 768])
        # import pdb;pdb.set_trace()

        return moe_result, (balance_loss+importance_loss), prob_gate_normalized

    # def _forward_gate_sentence(self, x, attention_mask):

    #     attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
    #     x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
    #     x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
    #     logits_gate = self.gate(x_average)
    #     prob_gate = F.softmax(logits_gate, dim=-1)
    #     gate = torch.argmax(prob_gate, dim=-1)

    #     order = gate.argsort(0)
    #     num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
    #     gate_load = num_sentences.clone()
    #     x = x[order]  # reorder according to expert number
    #     x = x.split(num_sentences.tolist(), dim=0)  # a list of length self.num_experts

    #     # compute the load balancing loss
    #     P = prob_gate.mean(0)
    #     temp = num_sentences.float()
    #     f = temp / temp.sum(0, keepdim=True)
    #     balance_loss = self.num_experts * torch.sum(P * f)

    #     prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
    #     prob_gate = prob_gate[order]
    #     prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)

    #     def forward_expert(input_x, prob_x, expert_idx):
    #         input_x = self.experts[expert_idx].forward(input_x)
    #         input_x = input_x * prob_x.unsqueeze(-1)
    #         return input_x

    #     result = []
    #     for i in range(self.num_experts):
    #         if x[i].size(0) > 0:
    #             result.append(forward_expert(x[i], prob_gate[i], i))
    #     result = torch.vstack(result)
    #     result = result[order.argsort(0)]  # restore original order

    #     return result, balance_loss, gate_load

    def _forward_gate_sentence(self, x, attention_mask):
        """
        x: query_attention_output, torch.Size([bz, 32, 768])
        attention_mask: torch.ones([bz, 32])

        ### Notice:
        The raw version of expert_attention_mask is the extended_attention_mask,
        which is added to attention_score directly;
        the values of extended_attention_mask are -0.0 or -10000,
        so it should be adjusted to a 1/0 version before being processed by the experts.
        """
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)                       # torch.Size([bz, 32, 768])
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
        logits_gate = self.gate(x_average)                                # torch.Size([bz, num_experts])
        prob_gate = F.softmax(logits_gate, dim=-1)                        # torch.Size([bz, num_experts])
        select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)  # gate: expert assigned to each sample, torch.Size([bz, topk])

        # weight the selected experts by l2 norm here
        if self.weight_type == 'l2_norm':
            # actually neither dim=0 nor dim=1 is right
            normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=1) # L2 normalization, torch.Size([bz, topk])
        elif self.weight_type == 'l2_norm_0':
            normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 normalization, torch.Size([bz, topk])
        elif self.weight_type == 'average':
            normalized_tensor = select_prob_gate / select_prob_gate.sum(dim=1, keepdim=True)
        elif self.weight_type == 'raw_prob':
            normalized_tensor = select_prob_gate

        num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
        gate_load = num_sentences.clone()

        # load balancing loss
        if self.use_balance_loss:
            balance_loss = self._balancing_loss(prob_gate, num_sentences)
        else:
            balance_loss = 0.0

        # importance loss
        importance_loss = self._importance_auxiliary_loss(prob_gate)

        # forward experts
        def forward_expert(input_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            return input_x

        result_lst = list()
        for i in range(self.topk):
            # process top1, top2, ... as separate groups: split the batch by gate, run each group through its expert, then scale by probability and sum
            tmp_gate = gate[:,i]
            tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1)
            order = tmp_gate.argsort(0)
            num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0)
            x1 = x[order]  # reorder according to expert number
            x1 = x1.split(num_sentences_t.tolist(), dim=0)  # a list of length self.num_experts

            result = []
            for i in range(self.num_experts):
                if x1[i].size(0) > 0:
                    result.append(forward_expert(x1[i], i))
            result = torch.vstack(result)
            result = result[order.argsort(0)]  # restore original order
            result_lst.append(result * tmp_prob)  # result * prob
            # result_lst.append(result) # result * prob # add_1212

        moe_result = sum(result_lst)
        return moe_result, (balance_loss+importance_loss), gate

    def _forward_sentence_single_expert(self, x, attention_mask):
        x_masked = x * attention_mask.unsqueeze(-1)
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
        logits_gate = self.gate(x_average)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        x = self.experts[gate.cpu().item()].forward(x)
        return x, 0.0, gate_load

    def forward(self, x, attention_mask):
        if self.route_method == "gate-token":
            x, balance_loss, gate_load = self._forward_gate_token(x)
        elif self.route_method == "gate-sentence":
            if x.size(0) == 1:
                x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
            else:
                x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
        elif self.route_method == "gate-sentence-post":
            x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
        else:
            raise KeyError("Routing method not supported.")
        # import pdb; pdb.set_trace()
        return x, balance_loss, gate_load

@ -92,7 +92,6 @@ class PrePromptMoE(PromptMoEBase):
        self.topk = topk
        if route_method in ["gate-token", "gate-single-token", "gate-sentence"]:
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
            print(self.gate)
        else:
            raise KeyError("Routing method not supported.")

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F

class RouteMoELayer(nn.Module):
    def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route"):
    def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
        # remove hash list
        nn.Module.__init__(self)
        self.num_experts = num_experts
@ -13,6 +13,7 @@ class RouteMoELayer(nn.Module):
        self.num_beams = num_beams
        self.hidden_size = hidden_size
        self.layer_judge = layer_judge
        self.weight_type = weight_type

        self.route_method = route_method
        if self.route_method == "pre-route":
@ -22,6 +23,17 @@
            self.gate = gate
            # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])

    def _importance_auxiliary_loss(self, prob_gate):
        # From VMOE
        # _importance_auxiliary_loss
        axis = tuple(range(prob_gate.ndim - 1))  # All except last.
        importance_per_expert = torch.sum(prob_gate, dim=axis)
        std_importance_per_expert = torch.std(importance_per_expert)
        mean_importance_per_expert = torch.mean(importance_per_expert)
        # Compute coefficient of variation (i.e. std/mean) squared.
        return (std_importance_per_expert / mean_importance_per_expert)**2

    def forward_gate(self, x):
        """
        x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
@ -29,7 +41,8 @@
        """
        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)                       # torch.Size([bz*num_beams, 32, 768])
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
        # x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
        x_average = torch.mean(x_masked, dim=1)                           # torch.Size([bz*num_beams, 768])
        logits_gate = self.gate(x_average)                                # torch.Size([bz*num_beams, num_experts])
        prob_gate = F.softmax(logits_gate, dim=-1)                        # torch.Size([bz*num_beams, num_experts])
        return prob_gate

@ -42,7 +55,7 @@
            topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: expert assigned to each sample, torch.Size([bz, topk])
            beam_scores = topk_values.view(self.num_beams * batch_size)           # torch.Size([bz * num_beams])
            expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1)    # torch.Size([bz * num_beams, 1])
            beam_idx = None
            beam_idx = torch.tensor(range(self.num_beams * batch_size))
        else:
            if self.layer_judge=='first' and self.route_method == 'post-route':
                batch_size = batch_size
@ -89,47 +102,56 @@

        return beam_scores, expert_route, beam_idx

    def forward_expert_ffn(self, x, expert_select, beam_scores):
    def forward_expert_ffn(self, x, expert_select, current_scores):
        """
        x_repeat : [bz*num_beams, 32, 768]
        expert_select : [bz*num_beams]
        current_scores : [bz*num_beams, num_experts] / [bz, num_experts]
        """
        # add_1212 l2_normalization
        # normalized_tensor = torch.nn.functional.normalize(beam_scores, p=2, dim=0) # L2 normalization, torch.Size([bz, topk])
        # add_1228 l2_normalization
        # normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 normalization, torch.Size([bz, topk])
        # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)

        # import pdb;pdb.set_trace()
        outputs = list()
        for i in range(x.shape[0]):
            output_x = self.experts[expert_select[i]].forward(x[i])
            outputs.append(output_x.unsqueeze(0))
        candidate_output = torch.cat(outputs)

        # candidate_output = candidate_output * tmp_prob
        return candidate_output # torch.Size([bz*num_beams, 32, 768])

        for i in range(self.num_experts):
            output_x = self.experts[i].forward(x)
            outputs.append(output_x.unsqueeze(1))
        candidate_output = torch.cat(outputs, dim=1)
        expert_select_matrix = F.one_hot(expert_select, self.num_experts)
        if self.weight_type == 'ffn_prob':
            tmp_prob = current_scores * expert_select_matrix
            candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
        else:
            candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
        output = torch.sum(candidate_output, dim=1)
        # import pdb;pdb.set_trace()
        return output # torch.Size([bz*num_beams, 32, 768])
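
A shape-level sketch of the new dense formulation above (nn.Identity stands in for the expert FFNs; all values invented). Compared with the per-sample loop it replaces, it runs every expert on every input (more FLOPs) but keeps the computation batched and differentiable through the gate scores:

import torch
import torch.nn as nn
import torch.nn.functional as F

num_experts, bz_beams, tokens, hidden = 3, 4, 32, 768
experts = nn.ModuleList([nn.Identity() for _ in range(num_experts)])
x = torch.randn(bz_beams, tokens, hidden)
expert_select = torch.tensor([0, 2, 1, 0])          # chosen expert per beam
current_scores = torch.rand(bz_beams, num_experts)  # gate probabilities

candidate = torch.cat([e(x).unsqueeze(1) for e in experts], dim=1)            # [bz*beams, num_experts, 32, 768]
select = F.one_hot(expert_select, num_experts)                                # [bz*beams, num_experts]
weighted = candidate * (current_scores * select).unsqueeze(-1).unsqueeze(-1)  # the 'ffn_prob' branch
output = weighted.sum(dim=1)                                                  # [bz*beams, 32, 768]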

    def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):

        current_scores = self.forward_gate(x) # [bz*num_beams, 5]
        current_scores = self.forward_gate(x) # [bz, num_experts] / [bz*num_beams, num_experts]

        importance_loss = self._importance_auxiliary_loss(current_scores)

        if use_log:
            current_scores_log = torch.log(current_scores) # scores can be summed directly after taking the log
        else:
            current_scores_log = current_scores

        # import pdb;pdb.set_trace()
        batch_size, num_tokens = x.shape[0], x.shape[1]
        beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)

        current_expert_select = expert_route[:,-1]

        if self.layer_judge=='first': # expand first dim to batch_size * num_beams
            replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
            x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32, 768]
            current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
            current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]

        candidate_output = self.forward_expert_ffn(x, current_expert_select, beam_scores) # [bz*num_beams, 32, 768]

        return candidate_output, beam_scores, expert_route, beam_idx
        input_x = x[beam_idx]
        candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32, 768]
        # import pdb;pdb.set_trace()
        return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

    def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
@ -145,12 +167,14 @@
        logits_gate_lst = list()
        for expert_idx in range(self.num_experts):
            output_x = forward_expert(x_masked, expert_idx)
            outputs.append(output_x.unsqueeze(0))
            output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
            # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
            output_x_aver = torch.mean(output_x, dim=1)
            # gate_score = self.gates[expert_idx](output_x_aver)
            gate_score = self.gate(output_x_aver)
            logits_gate_lst.append(gate_score)
        candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
            outputs.append(output_x.unsqueeze(0))

        candidate_output_raw = torch.cat(outputs)       # torch.Size([num_expert, bz*num_beam, 32, 768])
        logits_gate = torch.cat(logits_gate_lst, dim=1) # torch.Size([bz*num_beam, num_expert])
        current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])

@ -159,23 +183,32 @@
        else:
            current_scores_log = current_scores

        batch_size = x.shape[0] # bz*num_beam
        # importance loss
        importance_loss = self._importance_auxiliary_loss(current_scores)

        batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
        beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
        # beam_scores: torch.Size([bz*num_beam])
        # expert_route: torch.Size([bz*num_beam, layer_n])
        current_select_expert = expert_route[:,-1]
        # current_select_expert: torch.Size([bz*num_beam, 1])

        output = list()
        for i in range(beam_idx.shape[0]):
            b_idx = beam_idx[i]
            ex_idx = current_select_expert[i]
            ex_out = candidate_output[ex_idx, b_idx, :,:]
            output.append(ex_out.unsqueeze(0))
        if self.layer_judge == 'first':
            replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
            candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [num_experts, bz*num_beams, 32, 768]
            current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
            current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]

        final_output = torch.concat(output, dim=0)

        return final_output, beam_scores, expert_route, beam_idx
        candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([bz*num_beam, num_expert, 32, 768])
        expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
        if self.weight_type == 'ffn_prob':
            tmp_prob = current_scores[beam_idx] * expert_select_matrix
            output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
        else:
            output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
        final_output = torch.sum(output, dim=1)

        return final_output, beam_scores, expert_route, beam_idx, importance_loss
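
The reindexing done by the last block can be pictured with small tensors (hypothetical values): beams are first reordered to follow their surviving parents, then each beam keeps only its selected expert's output.

import torch

num_experts, n_beams, tokens, hidden = 2, 4, 32, 768
candidate_output_raw = torch.randn(num_experts, n_beams, tokens, hidden)
beam_idx = torch.tensor([0, 0, 3, 2])              # parent beam extended by each new beam
current_select_expert = torch.tensor([1, 0, 0, 1]) # expert chosen for each new beam

per_beam = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx]   # [n_beams, num_experts, 32, 768]
final = per_beam[torch.arange(n_beams), current_select_expert]  # [n_beams, 32, 768]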

    def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
        """
@ -184,11 +217,11 @@

        """
        if self.route_method == 'pre-route':
            candidate_output, beam_scores, expert_route, beam_idx = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
            candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
        elif self.route_method == "post-route":
            candidate_output, beam_scores, expert_route, beam_idx = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
            candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)

        return candidate_output, beam_scores, expert_route, beam_idx
        return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

294
minigpt4/models/moe/test_moe_layer.py
Normal file
@ -0,0 +1,294 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob, topk(softmax)'):
        # remove hash list
        nn.Module.__init__(self)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
        self.route_method = route_method
        self.topk = topk
        self.use_balance_loss = use_balance_loss
        self.weight_type = weight_type

        if route_method in ["gate-token", "gate-sentence"]:
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif route_method in ["gate-sentence-post"]:
            gate = nn.Linear(hidden_size, 1, bias=False).float()
            # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
            self.gate = gate
        else:
            raise KeyError("Routing method not supported.")

    def _balancing_loss(self, prob_gate, num_tokens):
        # From MOEBERT
        # compute the load balancing loss
        # prob_gate is [bz, num_expert]: the probability of each sample being assigned to each expert
        # equivalent to _gshard_auxiliary_loss in VMOE
        P = prob_gate.mean(0) # torch.Size([num_expert]); average probability assigned to each expert
        temp = num_tokens.float()
        f = temp / temp.sum(0, keepdim=True) # fraction of samples assigned to each expert
        balance_loss = self.num_experts * torch.sum(P * f)
        return balance_loss

    def _importance_auxiliary_loss(self, prob_gate):
        # From VMOE
        # _importance_auxiliary_loss
        axis = tuple(range(prob_gate.ndim - 1))  # All except last.
        importance_per_expert = torch.sum(prob_gate, dim=axis)
        std_importance_per_expert = torch.std(importance_per_expert)
        mean_importance_per_expert = torch.mean(importance_per_expert)
        # Compute coefficient of variation (i.e. std/mean) squared.
        return (std_importance_per_expert / mean_importance_per_expert)**2

    def _forward_gate_token(self, x):
        bsz, seq_len, dim = x.size()

        x = x.view(-1, dim)
        logits_gate = self.gate(x)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        order = gate.argsort(0)
        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_tokens.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts

        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_tokens.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)

        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)

        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x
            return input_x

        x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
        x = torch.vstack(x)
        x = x[order.argsort(0)]  # restore original order
        x = x.view(bsz, seq_len, dim)

        return x, balance_loss, gate_load

    def _forward_gate_sentence_post(self, x, attention_mask):
        """
        x: query_attention_output; torch.Size([bz, 32, 768])
        attention_mask: torch.ones([bz, 32])
        bz = 4
        x = torch.randn(bz,32,768)
        attention_mask = torch.ones([bz, 32])

        """
        # Prepare Input x
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])

        # FeedForward(x) & Forward Gate
        outputs = list()
        logits_gate_lst = list()
        for expert_idx in range(self.num_experts):
            output_x = self.experts[expert_idx].forward(x_masked)
            outputs.append(output_x.unsqueeze(0))

            output_x_aver = torch.mean(output_x, dim=1)
            # gate_score = self.gates[expert_idx](output_x_aver)
            gate_score = self.gate(output_x_aver)
            logits_gate_lst.append(gate_score)
        candidate_output = torch.cat(outputs)           # torch.Size([num_expert, bz, 32, 768])
        logits_gate = torch.cat(logits_gate_lst, dim=1) # torch.Size([bz, num_expert])

        # Probabilities for each sample of what expert it should be sent to.
        prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
        if 'softmax(topk)' in self.weight_type:
            prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1)
            select_prob_gate = F.softmax(prob_gate1, dim=-1)
        else:
            select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: expert assigned to each sample, torch.Size([bz, topk])

        # Calculate Balancing Loss
        if self.use_balance_loss:
            num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
            balance_loss = self._balancing_loss(prob_gate, num_sentences)
        else:
            balance_loss = 0.0
        # Calculate Importance Loss
        importance_loss = self._importance_auxiliary_loss(prob_gate)

        # Reshape Prob_gate & Gate
        # expert_mask: [batch_size, topk, num_experts]
        # expert_gate: [batch_size, topk, num_experts]
        # combine_tensor: [batch_size, num_experts]
        expert_mask = F.one_hot(gate, self.num_experts)
        expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
        combine_tensor = torch.sum(expert_gate, dim=1)
        # combine_tensor = torch.zeros_like(prob_gate)
        # combine_tensor.scatter_(1, gate, select_prob_gate) # equivalent operation, but may not be differentiable

        candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3))        # torch.Size([bz, num_expert, 32, 768])
        results = candidate_output_ad * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([bz, num_expert, 32, 768])
        outputs = torch.sum(results, dim=1)                                        # torch.Size([bz, 32, 768])
        import pdb;pdb.set_trace()

        return outputs, (balance_loss+importance_loss), combine_tensor
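
The two weight_type orderings are genuinely different; a tiny numeric check with made-up logits:

import torch
import torch.nn.functional as F

logits_gate = torch.tensor([[2.0, 1.0, 0.0]])
prob_gate = F.softmax(logits_gate, dim=-1)   # tensor([[0.665, 0.245, 0.090]])

# 'topk(softmax)': softmax over all experts first, then keep the top-k raw probabilities
p1, _ = torch.topk(prob_gate, 2, dim=1)      # tensor([[0.665, 0.245]]), sums to < 1
# 'softmax(topk)': keep the top-k logits first, then softmax over just those
l1, _ = torch.topk(logits_gate, 2, dim=1)
p2 = F.softmax(l1, dim=-1)                   # tensor([[0.731, 0.269]]), sums to 1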

    def pre_router(self, x, attention_mask):
        # Prepare input x
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
        x_average = torch.mean(x_masked, dim=1)     # torch.Size([bz, 768])

        # Forward Gate
        # logits_gate: [bz, num_experts]
        logits_gate = self.gate(x_average)

        # Probabilities for each sample of what expert it should be sent to.
        # prob_gate: [bz, num_experts]
        prob_gate = F.softmax(logits_gate, dim=-1)

        if 'softmax(topk)' in self.weight_type:
            prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1)
            select_prob_gate = F.softmax(prob_gate1, dim=-1)
        else:
            # topk(softmax)
            # Get Top-K experts for each sample
            # gate: [bz, topk]
            # select_prob_gate: [bz, topk]
            select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)

        # Reshape Prob_gate & Gate
        # expert_mask: [batch_size, topk, num_experts]
        # expert_gate: [batch_size, topk, num_experts]
        # combine_tensor: [batch_size, num_experts]
        expert_mask = F.one_hot(gate, self.num_experts)
        expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
        combine_tensor = torch.sum(expert_gate, dim=1)

        # Calculate Balancing Loss
        if self.use_balance_loss:
            num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
            balance_loss = self._balancing_loss(prob_gate, num_sentences)
        else:
            balance_loss = 0.0

        # Calculate Importance Loss
        importance_loss = self._importance_auxiliary_loss(prob_gate)

        import pdb; pdb.set_trace()

        return expert_mask, combine_tensor, balance_loss, importance_loss

    def _forward_gate_sentence(self, x, attention_mask):
        """
        x: query_attention_output, torch.Size([bz, 32, 768])
        attention_mask: torch.ones([bz, 32])

        ### Notice:
        The raw version of expert_attention_mask is the extended_attention_mask,
        which is added to attention_score directly;
        the values of extended_attention_mask are -0.0 or -10000,
        so it should be adjusted to a 1/0 version before being processed by the experts.
        """
        # Forward Router
        expert_mask, combine_tensor, balance_loss, importance_loss = self.pre_router(x, attention_mask)

        # Forward Expert FFN
        result = []
        for expert_idx in range(self.num_experts):
            output_x = self.experts[expert_idx].forward(x)
            result.append(output_x.unsqueeze(0))
        expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])

        # multiply outputs of experts by the routing probability
        expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
        outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states])

        import pdb; pdb.set_trace()

        return outputs, (balance_loss+importance_loss), combine_tensor

    def forward(self, x, attention_mask):
        if self.route_method == "gate-token":
            x, balance_loss, gate_load = self._forward_gate_token(x)
        elif self.route_method == "gate-sentence":
            x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
        elif self.route_method == "gate-sentence-post":
            x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
        else:
            raise KeyError("Routing method not supported.")
        # import pdb; pdb.set_trace()
        return x, balance_loss, gate_load

if __name__ == '__main__':

    import sys
    sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
    from minigpt4.models.QformerRouteMoE import BertConfig
    from minigpt4.models.QformerRouteMoE import FeedForward
    from minigpt4.models.moe.utils import (
        moe_layer_judge,
    )

    vision_width = 1408
    cross_attention_freq = 2
    num_query_token = 32
    # init_QformerMoE
    config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
    config.encoder_width = vision_width
    # insert cross-attention layer every other block
    config.add_cross_attention = True
    config.cross_attention_freq = cross_attention_freq
    config.query_length = num_query_token
    config.moebert_expert_num = 3
    config.moebert_num_beams = 2
    config.moebert_route_method = 'gate-sentence-post'
    config.moe_topk = 1
    config.use_balance_loss = False
    # config.moe_weight_type = 'raw_prob, softmax(topk)'
    config.moe_weight_type = 'raw_prob, topk(softmax)'

    batch_size = 4
    x2 = torch.randn(batch_size, 32, 768)
    beam_scores, expert_route = None, None

    for layer_num in [6, 8, 10]:
        layer_judge = moe_layer_judge(layer_num)
        ffn = FeedForward(config)
        gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()

        experts_moe = MoELayer(
            hidden_size=config.hidden_size,
            expert=ffn,
            num_experts=config.moebert_expert_num,
            route_method=config.moebert_route_method,
            topk=config.moe_topk,
            use_balance_loss=config.use_balance_loss,
            weight_type=config.moe_weight_type,
        )
        attn_mask = torch.ones([batch_size, 32])
        layer_output = experts_moe(x2, attn_mask)
        hidden_states3, aux_loss, combine_tensor = layer_output

        print(combine_tensor)
        print(aux_loss)
        x2 = hidden_states3

        print("------------------------------------")
    import pdb; pdb.set_trace()

@ -19,16 +19,34 @@ def use_experts(layer_idx):
    else:
        return False

def use_experts_route(layer_idx):
    # if layer_idx % 2 == 0:
    # use moe_ffn after cross_attns
    # if int(layer_idx) in [0,2,4,6,8,10]:
    if int(layer_idx) in [6,7,8,9,10,11]:
        return True
    else:
        return False

def moe_layer_judge(layer_idx):
    if layer_idx == 6:
        return 'first'
    elif layer_idx == 8:
    elif layer_idx in [7,8,9,10]:
        return 'mid'
    elif layer_idx == 10:
    elif layer_idx == 11:
        return 'last'
    else:
        return None

    # if layer_idx == 0:
    #     return 'first'
    # elif layer_idx in [2,4,6,8]:
    #     return 'mid'
    # elif layer_idx == 10:
    #     return 'last'
    # else:
    #     return None
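
A quick check of the new mapping using the function above:

for layer_idx in range(12):
    print(layer_idx, moe_layer_judge(layer_idx))
# 0-5  -> None     (no routed MoE)
# 6    -> 'first'  (routing state initialized here)
# 7-10 -> 'mid'
# 11   -> 'last'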

def process_ffn(model):
    if model.config.model_type == "bert":
        inner_model = model.bert

@ -10,7 +10,6 @@ model:
  load_finetuned: False
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  # finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/20231201184/checkpoint_best.pth"
  finetuned: ""
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

@ -38,7 +37,7 @@ model:

  # moe
  use_moeqformer: True
  moebert_expert_num: 5
  moebert_expert_num: 3
  moebert_route_method: "gate-sentence-post"
  moebert_load_balance: 0
  moe_topk: 1
@ -110,6 +109,7 @@ run:
  max_epoch: 1
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 1000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/"

@ -10,7 +10,7 @@ model:
  load_finetuned: True
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_linear_gate_3ex_3beam_1loss_top3layer_log_textinqf_epo3_1216/20231216155/checkpoint_best.pth"
  finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/20240112212/checkpoint_best.pth"
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
@ -38,10 +38,12 @@ model:
  # moe
  use_moeqformer: True
  use_route_moe: True
  moebert_expert_num: 3
  moebert_num_beams: 3
  moebert_route_method: "post-route"
  gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209_eval_latest1/"
  moebert_load_balance: 0
  moebert_expert_num: 2
  moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
  gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"

datasets:
  gqa:
@ -81,19 +83,20 @@ run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 2e-5
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 5
  max_epoch: 10
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 3000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"

  amp: True
  resume_ckpt_path: null

@ -38,10 +38,12 @@ model:
  # moe
  use_moeqformer: True
  use_route_moe: True
  moebert_route_method: "post-route"
  moebert_load_balance: 0
  moebert_expert_num: 3
  moebert_num_beams: 3
  moebert_route_method: "post-route"
  gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False

datasets:
  gqa: # train: 943000, 12578, 12578
@ -97,19 +99,20 @@ run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 2e-5
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 5
  max_epoch: 8
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 5000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_1gate_3ex_3beam_1loss_5e5lr_top6layer_textinqf_epo8_0117/"

  amp: True
  resume_ckpt_path: null
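
How these YAML fields would plausibly be threaded into the route-MoE constructor (a hedged sketch only; the real wiring lives in the QformerRouteMoE model code, and the variable names here are assumptions):

route_moe = RouteMoELayer(
    hidden_size=config.hidden_size,
    expert=ffn,                               # a FeedForward module, as in test_moe_layer.py
    num_experts=config.moebert_expert_num,    # moebert_expert_num: 3
    num_beams=config.moebert_num_beams,       # moebert_num_beams: 3
    layer_judge=moe_layer_judge(layer_idx),   # 'first' / 'mid' / 'last' per layer
    route_method=config.moebert_route_method, # "post-route"
    weight_type=config.moe_weight_type,       # 'ffn_prob'
)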
|
||||
|
@ -0,0 +1,129 @@
|
||||
# Copyright (c) 2022, salesforce.com, inc.
|
||||
# All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
# 0107test
|
||||
|
||||
model:
|
||||
arch: blip2_vicuna_instruct
|
||||
model_type: vicuna7b_pretrain
|
||||
load_pretrained: True
|
||||
load_finetuned: False
|
||||
vit_model: eva_clip_g
|
||||
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
|
||||
# finetuned: ""
|
||||
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
|
||||
|
||||
# vit encoder
|
||||
image_size: 224
|
||||
drop_path_rate: 0
|
||||
use_grad_checkpoint: False
|
||||
vit_precision: "fp16"
|
||||
|
||||
# Q-Former
|
||||
num_query_token: 32
|
||||
qformer_text_input: True
|
||||
|
||||
# vicuna
|
||||
llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
|
||||
prompt: ""
|
||||
max_txt_len: 256
|
||||
max_output_txt_len: 256
|
||||
|
||||
# freeze
|
||||
freeze_vit: True
|
||||
freeze_llm: True
|
||||
freeze_qformer: False
|
||||
freeze_t5_proj: False
|
||||
|
||||
# moe
|
||||
use_moeqformer: True
|
||||
use_route_moe: True
|
||||
moebert_route_method: "post-route"
|
||||
moebert_load_balance: 0
|
||||
moebert_expert_num: 2
|
||||
moebert_num_beams: 2
|
||||
moe_weight_type: 'ffn_prob'
|
||||
use_balance_loss: False
|
||||
# gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
|
||||
|
||||
datasets:
  # gqa: # train: 943000, 12578, 12578
  #   type: balanced_sft_raw
  #   batch_size: 1
  #   vis_processor:
  #     train:
  #       name: "blip2_image_train"
  #       image_size: 224
  #     eval:
  #       name: "blip2_image_eval"
  #       image_size: 224
  #   text_processor:
  #     train:
  #       name: "blip_caption"
  #     eval:
  #       name: "blip_caption"
  #   sample_ratio: 10

  ok_vqa: # train, valid (9009, 5046)
    batch_size: 1
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 1

  # coco_vqa: # 658104
  #   batch_size: 1
  #   vis_processor:
  #     train:
  #       name: "blip2_image_train"
  #       image_size: 224
  #     eval:
  #       name: "blip2_image_eval"
  #       image_size: 224
  #   text_processor:
  #     train:
  #       name: "blip_caption"
  #     eval:
  #       name: "blip_caption"
  #   sample_ratio: 9

run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  init_lr: 2e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 5
  num_workers: 4
  warmup_steps: 600

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]
  valid_splits: ["val"]
  # test_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
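
The config above only names the routing scheme; as a reading aid, here is a minimal sketch of what a "post-route" MoE layer with 'ffn_prob' weighting plausibly computes: every expert FFN runs first, and a linear gate then scores the expert outputs, which are summed under the softmaxed scores. Class and tensor names are assumptions for illustration, the layer-wise beam search implied by moebert_num_beams is omitted, and this is not the repo's implementation.

# Hypothetical sketch of post-route MoE with 'ffn_prob' weighting.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PostRouteMoE(nn.Module):
    def __init__(self, hidden=768, num_experts=2):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(),
                          nn.Linear(4 * hidden, hidden))
            for _ in range(num_experts))
        self.gate = nn.Linear(hidden, 1)   # scores each expert's *output*

    def forward(self, x):                                          # x: [B, S, H]
        outs = torch.stack([e(x) for e in self.experts], dim=1)    # [B, E, S, H]
        logits = self.gate(outs.mean(dim=2)).squeeze(-1)           # [B, E], post-FFN
        probs = F.softmax(logits, dim=-1)                          # the "ffn_prob"
        return (outs * probs[:, :, None, None]).sum(dim=1)         # weighted sum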

@ -10,7 +10,7 @@ model:
  load_finetuned: True
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_1048k_raw_QformerMoE_Route_Post_NoNorm_5ex_2beam_1loss_top3layer_textinqf_epo6_1215/20231216161/checkpoint_best.pth"
  finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/20240111145/checkpoint_best.pth"
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
@ -39,8 +39,11 @@ model:
  use_moeqformer: True
  use_route_moe: True
  moebert_route_method: "post-route"
  moebert_expert_num: 5
  moebert_load_balance: 0
  moebert_expert_num: 2
  moebert_num_beams: 2
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False

datasets:
  ok_vqa: # train, valid (9009, 5046)
@ -78,7 +81,7 @@ evaluation_datasets:
run:
  task: instruction_tuning
  name: vqa_benchmark_evaluation
  save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/eval/benchmarks/mix_1048k_raw_QformerMoE_Route_Post_NoNorm_5ex_2beam_1loss_top3layer_textinqf_epo6_1215/"
  save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/benchmarks/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/"
  seed: 42

@ -0,0 +1,131 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
  arch: blip2_vicuna_instruct
  model_type: vicuna7b_pretrain
  load_pretrained: True
  load_finetuned: False
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  # finetuned: ""
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"

  # Q-Former
  num_query_token: 32
  qformer_text_input: True

  # vicuna7b
  llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
  prompt: ""
  max_txt_len: 256
  max_output_txt_len: 256

  # freeze
  freeze_vit: True
  freeze_llm: True
  freeze_qformer: False
  freeze_t5_proj: False

  # moe
  use_moeqformer: True
  use_route_moe: True
  moebert_route_method: "post-route"
  moebert_load_balance: 0.05
  moebert_expert_num: 3
  moebert_num_beams: 3
  moe_weight_type: 'ffn_prob'
  use_balance_loss: False

datasets:
  gqa: # train: 943000, 12578, 12578
    type: balanced_sft_raw
    # batch_size: 16
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 50

  ok_vqa: # train, valid (9009, 5046)
    # batch_size: 16
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 8

  coco_vqa: # 658104
    # batch_size: 16
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 15

run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  # init_lr: 2e-5
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 8
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 5000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_3ex_3beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]
  valid_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True
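
Note that moebert_load_balance is 0.05 here (it was 0 in the previous config) while use_balance_loss stays False. Such a coefficient conventionally scales an auxiliary load-balancing term added to the task loss; a common Switch-Transformer-style form is given below as an assumed reference, not necessarily the repo's exact loss.

# Hypothetical balancing loss that a coefficient like moebert_load_balance
# would scale before being added to the task loss.
import torch
import torch.nn.functional as F

def load_balance_loss(router_probs: torch.Tensor) -> torch.Tensor:
    # router_probs: [num_tokens, num_experts], softmax output of the gate
    num_experts = router_probs.size(-1)
    importance = router_probs.mean(dim=0)        # mean gate probability per expert
    assignment = F.one_hot(router_probs.argmax(dim=-1),
                           num_experts).float().mean(dim=0)  # fraction routed
    # minimal (== 1) when both distributions are uniform across experts
    return num_experts * torch.sum(importance * assignment)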

@ -37,19 +37,19 @@ model:

  # moe
  use_moeqformer: True
  moebert_expert_num: 5
  moebert_expert_num: 3
  moebert_route_method: "gate-sentence"
  moebert_load_balance: 0
  moe_topk: 1
  use_balance_loss: False
  moe_weight_type: 'l2_norm'
  gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/"
  moe_weight_type: 'raw_prob'
  # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/"


datasets:
  gqa: # train: 94254
    type: balanced_sft_raw_part
    batch_size: 32
    batch_size: 1
    vis_processor:
      train:
        name: "blip2_image_train"
@ -65,7 +65,7 @@ datasets:
    sample_ratio: 50

  ok_vqa: # train, valid (9009, 5046)
    batch_size: 32
    batch_size: 1
    vis_processor:
      train:
        name: "blip2_image_train"
@ -80,22 +80,22 @@ datasets:
        name: "blip_caption"
    sample_ratio: 8

  coco_vqa: # 214352 vqa_val
    type: vqa_v2_part
    batch_size: 32
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 15
  # coco_vqa: # 214352 vqa_val
  #   type: vqa_v2_part
  #   batch_size: 1
  #   vis_processor:
  #     train:
  #       name: "blip2_image_train"
  #       image_size: 224
  #     eval:
  #       name: "blip2_image_eval"
  #       image_size: 224
  #   text_processor:
  #     train:
  #       name: "blip_caption"
  #     eval:
  #       name: "blip_caption"
  #   sample_ratio: 15

run:
  task: instruction_tuning
@ -108,12 +108,13 @@ run:
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 5
  max_epoch: 1
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 1000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/"
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1220_test/"

  amp: True
  resume_ckpt_path: null

@ -0,0 +1,125 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

model:
  arch: blip2_vicuna_instruct
  model_type: vicuna7b_pretrain
  load_pretrained: True
  load_finetuned: False
  vit_model: eva_clip_g
  pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
  # finetuned: ""
  q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"

  # Q-Former
  num_query_token: 32
  qformer_text_input: True

  # vicuna7b
  llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
  prompt: ""
  max_txt_len: 256
  max_output_txt_len: 256

  # freeze
  freeze_vit: True
  freeze_llm: True
  freeze_qformer: False
  freeze_t5_proj: False

  # moe
  use_moeqformer: False
  moebert_expert_num: 1
  moebert_route_method: "gate-sentence"
  moebert_load_balance: 0.05
  moe_topk: 1

datasets:
  gqa: # train: 943000, 12578, 12578
    type: balanced_sft_raw
    batch_size: 16
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 50

  ok_vqa: # train, valid (9009, 5046)
    batch_size: 16
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 8

  coco_vqa: # 658104
    batch_size: 16
    vis_processor:
      train:
        name: "blip2_image_train"
        image_size: 224
      eval:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"
      eval:
        name: "blip_caption"
    sample_ratio: 15

run:
  task: instruction_tuning
  # optimizer
  lr_sched: "linear_warmup_cosine_lr"
  # init_lr: 2e-5
  init_lr: 5e-5
  min_lr: 1e-6
  warmup_lr: 1e-6
  log_freq: 5
  save_freq: 1500

  weight_decay: 0.05
  max_epoch: 8
  num_workers: 4
  warmup_steps: 600
  iters_per_epoch: 5000

  seed: 42
  output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_1610k_raw_QformerMoE_train_qf_train_qt_1ex_top1_textinqf_epo8_lr5e5_seed42_0112/"

  amp: True
  resume_ckpt_path: null

  evaluate: False
  train_splits: ["train"]
  valid_splits: ["val"]

  device: "cuda"
  world_size: 1
  dist_url: "env://"
  distributed: True

@ -110,6 +110,7 @@ class RunnerBase:
                else:
                    p_wd.append(p)
                num_parameters += p.data.nelement()
            # import pdb; pdb.set_trace() # 0107test
            logging.info("number of trainable parameters: %d" % num_parameters)
            optim_params = [
                {
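
The hunk above is truncated mid-statement; for context, here is a simplified sketch of the parameter grouping this code performs in LAVIS-style runners (assumed shape, not the exact repo code): biases and norm weights are excluded from weight decay while trainable parameters are counted.

# Assumed surrounding logic of the hunk above; `model` is the wrapped model.
p_wd, p_non_wd = [], []
num_parameters = 0
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue                      # frozen ViT / LLM weights are skipped
    if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
        p_non_wd.append(p)            # no weight decay for biases / norms
    else:
        p_wd.append(p)
    num_parameters += p.data.nelement()
optim_params = [
    {"params": p_wd, "weight_decay": 0.05},   # cf. weight_decay in the configs
    {"params": p_non_wd, "weight_decay": 0},
]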

@ -241,10 +241,14 @@ class BaseTask:

            # after_train_step()
            if use_amp:
                # torch.autograd.set_detect_anomaly(True)
                # detect anomalous values (NaN/Inf) during backward to locate the offending code
                # with torch.autograd.detect_anomaly():
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # import pdb; pdb.set_trace() # 0107test
            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_amp:
@ -252,6 +256,9 @@ class BaseTask:
                    scaler.update()
                else:
                    optimizer.step()

                # import pdb; pdb.set_trace() # 0107test

                optimizer.zero_grad()
                # if self.cfg.wandb_log:
                # if self.cfg.run_cfg.wandb_log:
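
The commented-out lines above hint at the debugging workflow behind this commit: wrap the AMP backward pass in torch.autograd.detect_anomaly() so the traceback points at the op that produced NaN/Inf gradients. A minimal self-contained sketch of the same idea (placeholder names, debug use only; anomaly detection is slow):

import torch

scaler = torch.cuda.amp.GradScaler()

def debug_step(model, compute_loss, samples, optimizer):
    # detect_anomaly re-checks values during backward and raises with a
    # traceback at the op that produced NaN/Inf gradients
    with torch.autograd.detect_anomaly():
        with torch.cuda.amp.autocast():
            loss = compute_loss(model, samples)
        scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()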

@ -45,3 +45,5 @@ visualizer
tensorboard
kmeans_pytorch
visual_genome
gpustat
torchviz
36
setup.py
Normal file
@ -0,0 +1,36 @@
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from setuptools import setup, find_namespace_packages
import platform

DEPENDENCY_LINKS = []
if platform.system() == "Windows":
    DEPENDENCY_LINKS.append("https://download.pytorch.org/whl/torch_stable.html")


def fetch_requirements(filename):
    with open(filename) as f:
        return [ln.strip() for ln in f.read().split("\n")]


setup(
    name="PromptMoE",
    version="1.0.1",
    author="Hanzi Wang",
    description="PromptMoE & QformerMoE Based on LAVIS",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    keywords="Vision-Language, Multimodal, Image Captioning, Generative AI, Deep Learning, Library, PyTorch",
    license="3-Clause BSD",
    packages=find_namespace_packages(include="lavis.*"),
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.7.0",
    include_package_data=True,
    dependency_links=DEPENDENCY_LINKS,
    zip_safe=False,
)
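
With this setup.py in place, the project should install as an editable package in the standard setuptools way (e.g. pip install -e .), resolving its dependencies from requirements.txt via fetch_requirements.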
5570
test.pdf/backward_graph
Normal file
File diff suppressed because it is too large
Load Diff
BIN
test.pdf/backward_graph.pdf
Normal file
Binary file not shown.
360
test.txt
Normal file
@ -0,0 +1,360 @@
tmp_name = [name for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)]

tmp = [p for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)]

tensor([[-1.4032e-02,  3.7242e-03,  8.4997e-03,  ...,  4.8204e-03,
          3.9503e-02, -4.1356e-03]], requires_grad=True)

model.Qformer.bert.encoder.layer[10].experts.gate.weight

layer 11
0:
model.Qformer.bert.encoder.layer[11].output.dense.weight.grad
model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad

nan:
model.Qformer.bert.encoder.layer[11].attention.output.dense.weight.grad
model.Qformer.bert.encoder.layer[11].attention.self.query.weight.grad
model.Qformer.bert.encoder.layer[11].experts.intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.output_query.dense.weight.grad

None:
model.Qformer.bert.encoder.layer[11].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].output_query.dense.weight.grad

layer 8
0:
model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].output_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].output_query.dense.weight.grad

nan:
model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].output_query.dense.weight.grad
(Qformer) model.Qformer.bert.encoder.layer[8].intermediate_query.dense.weight.grad

None:
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad == None
model.Qformer.bert.encoder.layer[8].experts.gate.weight.requires_grad == True


model.Qformer.bert.encoder.layer[6].experts.gate.weight
Qformer.bert.encoder.layer.6.experts.gate.weight

tensor([[-0.0089, -0.0123, -0.0168,  ..., -0.0072,  0.0295, -0.0167],
        [ 0.0305,  0.0277, -0.0215,  ...,  0.0149,  0.0016, -0.0415],
        [ 0.0199,  0.0151,  0.0237,  ...,  0.0007,  0.0023,  0.0167]],
       requires_grad=True)

tensor([[-0.0089, -0.0123, -0.0168,  ..., -0.0072,  0.0295, -0.0167],
        [ 0.0305,  0.0277, -0.0215,  ...,  0.0149,  0.0016, -0.0415],
        [ 0.0199,  0.0151,  0.0237,  ...,  0.0007,  0.0023,  0.0167]],
       requires_grad=True)

tensor([[ 4.5972e-02, -1.5231e-02, -6.9533e-03,  ..., -6.8762e-03,
          1.4032e-02, -4.3389e-03]], requires_grad=True)
109
test1.txt
Normal file
@ -0,0 +1,109 @@
# dump the autograd graph of the Q-Former output to PDF
from torchviz import make_dot
dot = make_dot(query_output.last_hidden_state, params=dict(self.Qformer.bert.named_parameters()))
log_dir = '/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/'
dot.render(filename="Pre_PromptMoE_RawProb_backward_graph", directory=log_dir, format="pdf")


# Pre-Prompt-MoE
model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].dense1.weight.grad


model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight
model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].intermediate.dense.weight.grad
model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad

model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight
model.Qformer.bert.encoder.layer[10].experts.experts[2].dense1.weight
model.Qformer.bert.encoder.layer[10].experts.experts[1].dense1.weight
model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight
model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight == model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight

# Pre-MoE gate-sentence
# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad does not get updated

# Pre-MoE gate-token
# updates normally

# Post-MoE gate-sentence
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad updates normally
# model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad is all 0/-0
# model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad is all 0/-0

# Route-MoE
# Pre-MoE: the computed beam_scores are wrong

# Post-Route updates the parameters of multiple experts, and also updates the gate parameters
# Layer 6 updated the parameters of two experts (layer 6, layer 8)
# model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad is 0? all zeros
# model.Qformer.bert.encoder.layer[11].output.dense.weight.grad

model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[1].intermediate_query.dense.weight.grad

model.Qformer.bert.encoder.layer[7].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[7].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[7].experts.experts[1].intermediate_query.dense.weight.grad

model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad

model.Qformer.bert.encoder.layer[9].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[9].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[9].experts.experts[1].intermediate_query.dense.weight.grad

model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[1].intermediate_query.dense.weight.grad

model.Qformer.bert.encoder.layer[11].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[11].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.experts[1].intermediate_query.dense.weight.grad
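
# The layer-by-layer checks above can be automated; a minimal sketch of the
# same audit, assuming `model` is the instantiated Blip2 model from these notes:

import torch

def classify_grads(model):
    # classify every trainable parameter's gradient after loss.backward()
    buckets = {"none": [], "zero": [], "nan": [], "ok": []}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.grad is None:
            buckets["none"].append(name)        # never reached by backward
        elif torch.isnan(p.grad).any():
            buckets["nan"].append(name)         # numerical blow-up upstream
        elif torch.count_nonzero(p.grad) == 0:
            buckets["zero"].append(name)        # reached, but zero signal
        else:
            buckets["ok"].append(name)
    return buckets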

(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.10.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[-0.0328,  0.0414,  0.0010,  ..., -0.0068,  0.0244,  0.0587],
        [ 0.0120,  0.0458,  0.0171,  ..., -0.0439, -0.0107, -0.0397],
        [ 0.0239,  0.0191, -0.0145,  ...,  0.0008, -0.0067,  0.0090],
        ...,
        [ 0.0174, -0.0465, -0.0106,  ..., -0.0095,  0.0153, -0.0195],
        [-0.0151, -0.0082, -0.0320,  ..., -0.0016, -0.0232, -0.0147],
        [ 0.0142, -0.0286,  0.0161,  ..., -0.0160, -0.0306, -0.0272]],
       device='cuda:0', requires_grad=True)]
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.8.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[ 0.0024,  0.0218, -0.0186,  ..., -0.0178, -0.0067,  0.0820],
        [-0.0759, -0.0002, -0.0548,  ...,  0.0292,  0.0531,  0.0779],
        [-0.0220, -0.0037, -0.0520,  ..., -0.0426, -0.0261, -0.0357],
        ...,
        [-0.0448,  0.0471,  0.0133,  ..., -0.0062, -0.0217, -0.0203],
        [ 0.0532,  0.0197,  0.0320,  ..., -0.0010, -0.0838,  0.0682],
        [ 0.0284,  0.0038, -0.0007,  ..., -0.0305,  0.0296,  0.0056]],
       device='cuda:0', requires_grad=True)]
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.6.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[ 6.5176e-02, -4.6473e-02, -2.7396e-02,  ...,  2.1774e-03,
          6.1457e-02,  1.9180e-03],
        [ 7.3707e-03,  6.1392e-02, -2.7108e-02,  ...,  4.0778e-02,
         -1.9791e-02, -1.1612e-02],
        [ 2.1193e-02, -3.8323e-02, -6.0238e-02,  ..., -1.4539e-02,
          9.2965e-02,  3.9153e-02],
        ...,
        [ 5.3203e-03, -1.7276e-02, -3.2191e-02,  ..., -1.6435e-02,
         -1.8553e-02, -2.8158e-02],
        [-6.9853e-02,  9.2719e-03, -1.8895e-03,  ..., -2.6425e-02,
          1.4880e-03,  3.4505e-02],
        [-1.2168e-03,  3.7038e-02,  4.8047e-02,  ..., -3.4523e-03,
         -1.3030e-05, -1.4778e-02]], device='cuda:0', requires_grad=True)]