Route MoE Promote (Post/Pre) update 0110

commit eb022668a3
parent ce67e5669a
Authored by root on 2024-01-10 16:56:52 +08:00, committed by wanghanzi
38 changed files with 12983 additions and 491 deletions

File diff suppressed because it is too large.

Binary file not shown.


@@ -1,4 +1,4 @@
-name: minigptv
+name: promptmoe
 channels:
 - pytorch
 - defaults


@@ -17,14 +17,14 @@ datasets:
 # md5: aa31ac474cf6250ebb81d18348a07ed8
 storage:
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_train.json
-val:
-url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
-storage:
-- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
-test:
-url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
-storage:
-- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json
+# val:
+# url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
+# storage:
+# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
+# test:
+# url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
+# storage:
+# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json
 images:
 storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO


@@ -20,6 +20,7 @@ datasets:
 - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json
 storage:
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val_eval.json
+# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_mscoco_val2014_annotations.json
@@ -29,6 +30,7 @@ datasets:
 - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json
 storage:
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_test.json
+# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json
 images:


@@ -20,6 +20,7 @@ datasets:
 - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
 storage:
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
+# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json
@@ -32,6 +33,7 @@ datasets:
 - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
 storage:
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
+# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
 # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval_part100.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
 - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json


@@ -105,6 +105,8 @@ class COCOCaptionDataset(BaseDataset, __DisplMixin):
 'Using language, provide a short account of the image.',
 'Use a few words to illustrate what is happening in the picture.',
 ]
+self.source = 'coco_cap'
 def __getitem__(self, index):
 # TODO this assumes image input, not general enough
@@ -118,13 +120,20 @@ class COCOCaptionDataset(BaseDataset, __DisplMixin):
 image = self.vis_processor(image)
 caption = self.text_processor(ann["caption"])
-instruction = random.choice(self.instruction_pool)
-instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
+# instruction = random.choice(self.instruction_pool)
+# instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
+q_input = ""
+llm_input = random.choice(self.instruction_pool)
 return {
 "image": image,
+"image_id": ann["image"],
 "answer": caption,
-"instruction_input": instruction,
+"q_input": q_input,
+"llm_input": llm_input,
+"text_input": llm_input,
+"text_output": caption,
+"source": 'coco_cap',
 }
 class CaptionEvalDataset(BaseDataset, __DisplMixin):
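Note: the captioning samples now carry a separate Q-Former text input and LLM instruction. A minimal sketch (not from the repo) of the record this __getitem__ produces, with placeholder image/caption values and field names taken from the diff:

import random
import torch

image = torch.zeros(3, 224, 224)            # stand-in for vis_processor output
caption = "a cat sitting on a windowsill"   # stand-in for text_processor output
instruction_pool = ["Briefly describe this image."]

llm_input = random.choice(instruction_pool)
sample = {
    "image": image,
    "image_id": "placeholder_image_id.jpg",  # ann["image"] in the dataset
    "answer": caption,
    "q_input": "",            # the Q-Former gets no extra text in this setup
    "llm_input": llm_input,   # instruction fed to the LLM
    "text_input": llm_input,
    "text_output": caption,
    "source": "coco_cap",
}
print(sorted(sample.keys()))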


@@ -31,6 +31,7 @@ class COCOCapEvalDataset(CaptionEvalDataset):
 split (string): val or test
 """
 super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+self.source = 'coco_cap'
 def __getitem__(self, index):
 ann = self.annotation[index]


@@ -31,7 +31,6 @@ class MultiIterLoader:
 if ratios is None:
 ratios = [1.0] * len(loaders)
 else:
-# import pdb; pdb.set_trace()
 assert len(ratios) == len(loaders)
 ratios = [float(ratio) / sum(ratios) for ratio in ratios]


@@ -12,7 +12,6 @@ from tqdm import tqdm
 import torch
 from torch.utils.data import DataLoader
 import torch.backends.cudnn as cudnn
-from datasets import load_dataset
 import sys
 sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
@@ -248,6 +247,7 @@ if 'vsr' in args.dataset:
 img_path = cfg.evaluation_datasets_cfg["vsr"]["img_path"]
 batch_size = cfg.evaluation_datasets_cfg["vsr"]["batch_size"]
 max_new_tokens = cfg.evaluation_datasets_cfg["vsr"]["max_new_tokens"]
+from datasets import load_dataset
 annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test')
 data = VSREvalData(annotation, vis_processor, img_path)


@@ -386,17 +386,23 @@ class BertOutput(nn.Module): # Add & Norm
 class FeedForward(nn.Module):
+# remove LayerNorm
 def __init__(self, config):
-nn.Module.__init__(self)
-# first layer
-self.intermediate_query = BertIntermediate(config)
-# second layer
-self.output_query = BertOutput(config)
+super().__init__()
+self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
+if isinstance(config.hidden_act, str):
+self.intermediate_act_fn = ACT2FN[config.hidden_act]
+else:
+self.intermediate_act_fn = config.hidden_act
+self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
+self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2
+# self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2
 def forward(self, hidden_states: Tensor):
-input_tensor = hidden_states
-intermediate_output = self.intermediate_query(hidden_states)
-hidden_states = self.output_query(intermediate_output, input_tensor)
+hidden_states = self.dense1(hidden_states)
+hidden_states = self.intermediate_act_fn(hidden_states)
+hidden_states = self.dense2(hidden_states)
+hidden_states = self.dropout(hidden_states)
 return hidden_states
@@ -440,6 +446,7 @@ class BertLayer(nn.Module):
 )
 else:
 self.experts = ffn
+self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 def forward(
 self,
@@ -494,7 +501,8 @@ class BertLayer(nn.Module):
 moe_ffn_attention_input = query_attention_output[:, :query_length, :]
 moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
 layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask) # layer_output, gate_loss, gate_load
+# import pdb; pdb.set_trace() # test0107
 if attention_output.shape[1] > query_length: # have text input in Qformer
 layer_output_text = apply_chunking_to_forward(
 self.feed_forward_chunk,
@@ -503,6 +511,7 @@ class BertLayer(nn.Module):
 attention_output[:, query_length:, :],
 )
 layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])
 else:
 layer_output = apply_chunking_to_forward(
 self.feed_forward_chunk,
@@ -524,15 +533,14 @@ class BertLayer(nn.Module):
 def feed_forward_query_moe(self, attention_output, expert_attention_mask):
 if not self.use_experts:
-layer_output = self.experts(attention_output)
+hidden_states = self.experts(attention_output)
+layer_output = self.expert_ln(hidden_states + attention_output)
 return layer_output, 0.0, []
-# if not self.importance_processor.is_moe:
-# raise RuntimeError("Need to turn the model to a MoE first.")
-layer_output, gate_loss, gate_load = self.experts(
+hidden_states, gate_loss, gate_load = self.experts(
 attention_output, expert_attention_mask
 )
+layer_output = self.expert_ln(hidden_states + attention_output)
 return layer_output, gate_loss, gate_load
 class BertEncoder(nn.Module):
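Note: this hunk pulls LayerNorm and the residual connection out of the expert FFN and applies them once through the shared expert_ln. Below is a self-contained sketch of that structure under the diff's naming (dense1, dense2, expert_ln); the sizes and the GELU stand-in for ACT2FN[config.hidden_act] are assumptions, not the repo's config.

import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Expert FFN without LayerNorm/residual; those now live outside, in expert_ln."""
    def __init__(self, hidden_size=768, intermediate_size=3072, dropout=0.1):
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = nn.GELU()   # stand-in for ACT2FN[config.hidden_act]
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense2(hidden_states)
        return self.dropout(hidden_states)

# Residual + LayerNorm are applied once, after the (MoE) FFN, as in feed_forward_query_moe:
hidden = torch.randn(2, 32, 768)             # query tokens entering the FFN block
expert = FeedForward()
expert_ln = nn.LayerNorm(768, eps=1e-12)
out = expert_ln(expert(hidden) + hidden)
print(out.shape)                             # torch.Size([2, 32, 768])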


@@ -46,10 +46,9 @@ from transformers.utils import logging
 from transformers.models.bert.configuration_bert import BertConfig
 from minigpt4.models.moe.utils import (
-FeedForward,
 MoEModelOutput,
 MoEModelOutputWithPooling,
-use_experts,
+use_experts_route,
 moe_layer_judge,
 )
 from minigpt4.models.moe.route_moe_layer import RouteMoELayer
@@ -378,13 +377,14 @@ class BertOutput(nn.Module): # Add & Norm
 def __init__(self, config):
 super().__init__()
 self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1
 self.dropout = nn.Dropout(config.hidden_dropout_prob)
 def forward(self, hidden_states, input_tensor):
 hidden_states = self.dense(hidden_states)
 hidden_states = self.dropout(hidden_states)
-hidden_states = self.LayerNorm(hidden_states + input_tensor)
+# Move LayerNorm & ResNet out of FFN After MoEFFN
+hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1
 return hidden_states
@@ -429,7 +429,7 @@ class BertLayer(nn.Module):
 self.output_query = BertOutput(config)
 # Add MoE FFN
-self.use_experts = use_experts(layer_num)
+self.use_experts = use_experts_route(layer_num)
 self.layer_judge = moe_layer_judge(layer_num)
 self.num_beams = config.moebert_num_beams
 ffn = FeedForward(config)
@@ -442,10 +442,13 @@ class BertLayer(nn.Module):
 num_beams=config.moebert_num_beams,
 layer_judge = self.layer_judge,
 route_method=config.route_method,
+weight_type=config.moe_weight_type,
 )
 else:
 self.experts = ffn
+# self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 def forward(
 self,
 hidden_states,
@@ -463,8 +466,8 @@ class BertLayer(nn.Module):
 self_attn_past_key_value = (
 past_key_value[:2] if past_key_value is not None else None
 )
-# import pdb;pdb.set_trace()
+# import pdb; pdb.set_trace() # 0107test
 # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same
 if self.num_beams > 1:
 if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams:
@@ -494,10 +497,6 @@ class BertLayer(nn.Module):
 present_key_value = self_attention_outputs[-1]
-# import pdb;pdb.set_trace()
-# print(self.layer_num, hidden_states.shape, attention_mask.shape)
 if query_length > 0:
 query_attention_output = attention_output[:, :query_length, :]
@@ -526,7 +525,8 @@ class BertLayer(nn.Module):
 moe_ffn_attention_input = query_attention_output[:, :query_length, :]
 moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
 layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
-# layer_output = (layer_output, beam_scores, expert_route, beam_idx)
+# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
+# import pdb; pdb.set_trace() # 0107test
 if attention_output.shape[1] > query_length: # have text input in Qformer
 layer_output_text = apply_chunking_to_forward(
@@ -535,7 +535,8 @@ class BertLayer(nn.Module):
 self.seq_len_dim,
 attention_output[:, query_length:, :],
 )
-if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
+if self.layer_judge == 'first' and self.num_beams>1:
+# if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
 # adjust the dimension of layer_output_text to bz*num_beams
 layer_output_text = self.adjust_layer_output_text(layer_output_text)
@@ -550,7 +551,8 @@ class BertLayer(nn.Module):
 # layer_output & layer_output_text dimen_0 from bz*num_beams to bz
 layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text)
-layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])
+layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4])
+# import pdb; pdb.set_trace() # 0107test
 else:
 layer_output = apply_chunking_to_forward(
@@ -559,7 +561,7 @@ class BertLayer(nn.Module):
 self.seq_len_dim,
 attention_output,
 )
-layer_output = (layer_output, None, None)
+layer_output = (layer_output, None, None, None, 0.0)
 outputs = (layer_output,) + outputs
@@ -594,24 +596,27 @@ class BertLayer(nn.Module):
 beam_scores_new = beam_scores[selects]
 expert_route_new = expert_route[selects]
-return (hidden_states_new, beam_scores_new, expert_route_new), layer_output_text
+return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text
 def feed_forward_chunk(self, attention_output):
 intermediate_output = self.intermediate(attention_output)
 layer_output = self.output(intermediate_output, attention_output)
+# layer_output = self.LayerNorm(layer_output + attention_output)
 return layer_output
 def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
 if not self.use_experts:
 layer_output = self.experts(attention_output)
-return layer_output, None, None, None
+# layer_output = self.LayerNorm(layer_output + attention_output)
+return layer_output, None, None, None, 0.0
-layer_output, beam_scores, expert_route, beam_idx = self.experts(
+layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
 attention_output, expert_attention_mask, beam_scores, expert_route
 )
-return layer_output, beam_scores, expert_route, beam_idx
+# layer_output = self.LayerNorm(layer_output + attention_output)
+return layer_output, beam_scores, expert_route, beam_idx, importance_loss
 class BertEncoder(nn.Module):
 def __init__(self, config):
@@ -645,6 +650,7 @@ class BertEncoder(nn.Module):
 next_decoder_cache = () if use_cache else None
 beam_scores=None
 expert_route=None
+importance_loss = 0
 for i in range(self.config.num_hidden_layers):
 layer_module = self.layer[i]
@@ -693,6 +699,7 @@ class BertEncoder(nn.Module):
 hidden_states = layer_outputs[0][0]
 beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1]
 expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2]
+importance_loss += layer_outputs[0][4]
 if use_cache:
 next_decoder_cache += (layer_outputs[-1],)
@@ -724,6 +731,7 @@ class BertEncoder(nn.Module):
 cross_attentions=all_cross_attentions,
 beam_scores=beam_scores,
 expert_route=expert_route,
+gate_loss=importance_loss,
 )
@@ -1103,6 +1111,7 @@ class BertModel(BertPreTrainedModel):
 cross_attentions=encoder_outputs.cross_attentions,
 beam_scores=encoder_outputs.beam_scores,
 expert_route=encoder_outputs.expert_route,
+gate_loss=encoder_outputs.gate_loss
 )
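Note: after this change each routed BertLayer returns a 5-tuple (hidden_states, beam_scores, expert_route, beam_idx, importance_loss), and the encoder sums the per-layer importance losses into the gate_loss field of its output. A rough sketch (not the repo's code) of that accumulation loop, with a dummy function standing in for BertLayer.forward:

import torch

def fake_layer(hidden_states, is_moe_layer):
    """Stand-in for BertLayer.forward: returns the 5-tuple described in the diff."""
    if is_moe_layer:
        beam_scores = torch.rand(hidden_states.shape[0])
        expert_route = torch.randint(0, 2, (hidden_states.shape[0], 1))
        importance = torch.rand(())              # per-layer importance auxiliary loss
        return hidden_states, beam_scores, expert_route, None, importance
    return hidden_states, None, None, None, 0.0

hidden = torch.randn(4, 32, 768)
beam_scores, expert_route, importance_loss = None, None, 0
for layer_idx in range(12):
    out = fake_layer(hidden, is_moe_layer=(layer_idx >= 6))
    hidden = out[0]
    beam_scores = beam_scores if out[1] is None else out[1]
    expert_route = expert_route if out[2] is None else out[2]
    importance_loss += out[4]                    # reported as gate_loss in the encoder output
print(float(importance_loss))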


@@ -62,7 +62,7 @@ class Blip2Base(BaseModel):
 return Qformer, query_tokens
 @classmethod
-def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, cross_attention_freq=2):
+def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2):
 moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
 moe_encoder_config.encoder_width = vision_width
@@ -74,6 +74,7 @@ class Blip2Base(BaseModel):
 moe_encoder_config.moebert_expert_num = moebert_expert_num
 moe_encoder_config.moebert_num_beams = moebert_num_beams
 moe_encoder_config.route_method = route_method
+moe_encoder_config.moe_weight_type = moe_weight_type
 RouteMoEQformer = BertMoERouteLMHeadModel.from_pretrained(
 "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config


@@ -99,6 +99,7 @@ class Blip2VicunaInstruct(Blip2Base):
 moebert_expert_num=moebert_expert_num,
 moebert_num_beams=moebert_num_beams,
 route_method=moebert_route_method,
+moe_weight_type=moe_weight_type,
 cross_attention_freq=2
 )
 else:
@@ -118,7 +119,6 @@
 num_query_token, self.visual_encoder.num_features
 )
-# import pdb;pdb.set_trace()
 if not qformer_text_input:
 self.Qformer.bert.embeddings.word_embeddings = None
 self.Qformer.bert.embeddings.position_embeddings = None
@@ -178,6 +178,19 @@
 if "_query" in name and "experts" not in name: # raw ffn_query not update
 param.requires_grad = False
+ln_pattern = r"bert\.encoder\.layer\.\d+\.expert_ln\.(weight|bias)"
+if re.match(ln_pattern, name):
+key_orig = re.sub('expert_ln', 'output_query.LayerNorm', name)
+param.data.copy_(state_dict[key_orig])
+d1_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense1\.(weight|bias)"
+if re.match(d1_pattern, name):
+key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense1', 'intermediate_query.dense', name)
+param.data.copy_(state_dict[key_orig])
+d2_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense2\.(weight|bias)"
+if re.match(d2_pattern, name):
+key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense2', 'output_query.dense', name)
+param.data.copy_(state_dict[key_orig])
 # freeze qformer
 if freeze_qformer:
 for name, param in self.Qformer.named_parameters():
@@ -205,6 +218,7 @@
 self.use_moeqformer = use_moeqformer
 self.use_route_moe = use_route_moe
 self.moebert_load_balance = moebert_load_balance
+self.moebert_num_beams = moebert_num_beams
 self.gate_save_path = gate_save_path
 # if self.gate_save_path != None:
@@ -242,7 +256,7 @@
 # print(samples["text_input"])
 # print(samples["text_output"])
 # print('-----------------')
-# import pdb;pdb.set_trace()
+# import pdb;pdb.set_trace() # 0107test
 image = samples["image"]
 with self.maybe_autocast():
 image_embeds = self.ln_vision(self.visual_encoder(image))
@@ -278,10 +292,10 @@
 return_dict=True,
 output_hidden_states=True,
 )
+# import pdb; pdb.set_trace()# 0107test
 query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:]
-if self.use_moeqformer and not self.use_route_moe:
+if self.use_moeqformer:
 gate_loss = query_output.gate_loss # only available in QformerMoE
 if self.gate_save_path != None:
@@ -312,7 +326,7 @@
 # 'gate_route_1': prob_gate_normalized[0][i].tolist(),
 })
 # for layer in [6,8,10]:
-# layer_data = all_hidden_states[layer]
+# layer_data = all_hidden_states[layer]s
 # file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
 # x = layer_data.data.cpu().numpy()
 # np.save(file_path,x)
@@ -323,7 +337,6 @@
 print("Gate Save Error....")
 print(e)
 inputs_llm = self.llm_proj(query_output_to_linear)
 atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
@@ -380,7 +393,7 @@
 labels=targets,
 )
-if self.use_moeqformer and not self.use_route_moe:
+if self.use_moeqformer:
 loss = outputs.loss + self.moebert_load_balance * gate_loss
 else:
 loss = outputs.loss
@@ -441,6 +454,8 @@
 output_hidden_states=True,
 )
+# import pdb; pdb.set_trace()
 if self.gate_save_path != None:
 all_hidden_states = query_output.hidden_states
 # prob_gate_normalized = query_output.gate_loads
@@ -471,11 +486,11 @@
 # 'gate_route_3': prob_gate_normalized[2][i].tolist(),
 # 'gate_route_1': prob_gate_normalized[0][i].tolist(),
 })
-for layer in [6,8,10]:
-if layer == 6:
-layer_data = all_hidden_states[layer][i, :32, :]
+for layer in [6,7,8,9,10,11]:
+if layer in [6,11]:
+layer_data = all_hidden_states[layer][i, :, :]
 else:
-layer_data = all_hidden_states[layer][i*3, :32, :]
+layer_data = all_hidden_states[layer][i*self.moebert_num_beams, :, :]
 file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
 x = layer_data.data.cpu().numpy()
 np.save(file_path,x) # all done
@@ -683,5 +698,6 @@
 for name, param in model.named_parameters():
 if param.requires_grad == True:
 print(name)
+# [name for name, param in model.named_parameters() if (param.requires_grad == False and 'Qformer' in name and 'intermediate_query' in name)]
+# import pdb; pdb.set_trace()# 0107test
 return model
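Note: the regex block added above initializes the new MoE parameters from the pretrained Q-Former: each expert's dense1/dense2 is copied from intermediate_query.dense/output_query.dense, and expert_ln from output_query.LayerNorm. A standalone sketch of that key mapping on a toy state dict; the parameter names follow the diff, the shapes and example names below are made up for illustration.

import re
import torch

# Toy "pretrained" state dict for one layer (shapes shrunk for the example).
state_dict = {
    "bert.encoder.layer.6.intermediate_query.dense.weight": torch.randn(16, 8),
    "bert.encoder.layer.6.output_query.dense.weight": torch.randn(8, 16),
    "bert.encoder.layer.6.output_query.LayerNorm.weight": torch.ones(8),
}

# New parameter names introduced by the MoE refactor.
new_params = [
    "bert.encoder.layer.6.experts.experts.0.dense1.weight",
    "bert.encoder.layer.6.experts.experts.1.dense2.weight",
    "bert.encoder.layer.6.expert_ln.weight",
]

for name in new_params:
    if re.match(r"bert\.encoder\.layer\.\d+\.expert_ln\.(weight|bias)", name):
        key = re.sub("expert_ln", "output_query.LayerNorm", name)
    elif re.match(r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense1\.(weight|bias)", name):
        key = re.sub(r"experts(\.|\.experts\.\d+\.)dense1", "intermediate_query.dense", name)
    elif re.match(r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense2\.(weight|bias)", name):
        key = re.sub(r"experts(\.|\.experts\.\d+\.)dense2", "output_query.dense", name)
    else:
        continue
    print(name, "<-", key, tuple(state_dict[key].shape))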


@@ -21,7 +21,6 @@ class MoELayer(nn.Module):
 else:
 raise KeyError("Routing method not supported.")
 def _forward_gate_sentence(self, x, attention_mask):
 """
 x: query_attention_output , torch.Size([bz, 32, 768])
@@ -77,7 +76,65 @@ class MoELayer(nn.Module):
 print('Layer Qformer MoE: \n',prob_gate)
 return moe_result, select_prob_gate, gate
+def _forward_gate_sentence_post(self, x, attention_mask):
+"""
+x: query_attention_output; torch.Size([bz, 32, 768])
+attention_mask: torch.ones([bz, 32])
+bz = 4
+x = torch.randn(bz,32,768)
+attention_mask = torch.ones([bz, 32])
+"""
+attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
+x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
+def forward_expert(input_x, expert_idx):
+# input_x += torch.randn(4,32,768)
+# return input_x
+output_x = self.experts[expert_idx].forward(input_x)
+return output_x
+outputs = list()
+logits_gate_lst = list()
+for expert_idx in range(self.num_experts):
+output_x = forward_expert(x_masked, expert_idx)
+outputs.append(output_x.unsqueeze(0))
+output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
+# gate_acore = self.gates[expert_idx](output_x_aver)
+gate_score = self.gate(output_x_aver)
+logits_gate_lst.append(gate_score)
+candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768])
+logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert])
+prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
+topk_values, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: the expert(s) assigned to each sample, torch.Size([bz, topk])
+num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
+gate_load = num_sentences.clone()
+# load balancing loss
+if self.use_balance_loss:
+balance_loss = self._balancing_loss(prob_gate, num_sentences)
+else:
+balance_loss = 0.0
+# importance loss
+importance_loss = self._importance_auxiliary_loss(prob_gate)
+# output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
+# output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
+# logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])
+prob_gate_topk = torch.zeros_like(prob_gate)
+prob_gate_topk.scatter_(1, gate, topk_values)
+prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) # torch.Size([bz, num_expert])
+candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
+results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
+moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
+import pdb;pdb.set_trace()
+return moe_result, (balance_loss+importance_loss), prob_gate_normalized
 def forward(self, x, attention_mask):
 if self.route_method == "gate-token":
 x, balance_loss, gate_load = self._forward_gate_token(x)
@@ -95,7 +152,7 @@ class MoELayer(nn.Module):
 class RouteMoELayer(nn.Module):
-def __init__(self, hidden_size, expert, gate, num_experts, num_beams=2, layer_judge=None, route_method="pre-route"):
+def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
 # remove hash list
 nn.Module.__init__(self)
 self.num_experts = num_experts
@@ -103,13 +160,26 @@ class RouteMoELayer(nn.Module):
 self.num_beams = num_beams
 self.hidden_size = hidden_size
 self.layer_judge = layer_judge
+self.weight_type = weight_type
 self.route_method = route_method
 if self.route_method == "pre-route":
 self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
 elif self.route_method == "post-route":
-# gate = nn.Linear(hidden_size, 1, bias=False).float()
-self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
+gate = nn.Linear(hidden_size, 1, bias=False).float()
+self.gate = gate
+# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
+def _importance_auxiliary_loss(self, prob_gate):
+# From VMOE
+# _importance_auxiliary_loss
+axis = tuple(range(prob_gate.ndim - 1)) # All except last.
+importance_per_expert = torch.sum(prob_gate, dim=axis)
+std_importance_per_expert = torch.std(importance_per_expert)
+mean_importance_per_expert = torch.mean(importance_per_expert)
+# Compute coefficient of variation (i.e. std/mean) squared.
+return (std_importance_per_expert / mean_importance_per_expert)**2
 def forward_gate(self, x):
 """
@@ -123,19 +193,21 @@ class RouteMoELayer(nn.Module):
 prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
 return prob_gate
-def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
-import pdb;pdb.set_trace()
+def beam_search_backup(self, current_scores_log, beam_scores, expert_route, batch_size):
 if self.layer_judge=='first' and self.route_method=='pre-route':
+# current_scores_log torch.Size([bz, num_experts])
 assert beam_scores==None and expert_route==None
 current_scores = torch.exp(current_scores_log)
 topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
 beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
 expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
-beam_idx = None
+beam_idx = torch.tensor(range(self.num_beams * batch_size))
 else:
 if self.layer_judge=='first' and self.route_method == 'post-route':
 batch_size = batch_size
-next_scores_raw1 = torch.exp(current_scores_log) # torch.Size([bz, num_experts])
+next_scores_raw1 = torch.exp(current_scores_log) # torch.Size([bz, num_beams*num_experts])
 else:
 batch_size = int(batch_size // self.num_beams)
 next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking log, probabilities can be summed directly
@@ -147,9 +219,6 @@ class RouteMoELayer(nn.Module):
 next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
 # next_scores torch.Size([bz, num_beams])
 # next_tokens torch.Size([bz, num_beams])
-print(next_scores_raw1)
-print(next_scores)
-print(next_experts)
 next_batch_beam = list()
 for batch_idx in range(batch_size):
@@ -166,7 +235,7 @@ class RouteMoELayer(nn.Module):
 next_batch_beam.extend(next_sent_beam)
-import pdb;pdb.set_trace()
 if self.layer_judge=='first' and self.route_method == 'post-route':
 beam_scores = next_scores.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
 expert_route = next_experts.view(self.num_beams * batch_size)
@@ -181,33 +250,91 @@ class RouteMoELayer(nn.Module):
 pre_route = expert_route[beam_idx,:]
 expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
-import pdb;pdb.set_trace()
+return beam_scores, expert_route, beam_idx
+def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
+if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
+# current_scores_log torch.Size([bz, num_experts])
+assert beam_scores==None and expert_route==None
+current_scores = torch.exp(current_scores_log)
+topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
+beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
+expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
+beam_idx = torch.tensor(range(self.num_beams * batch_size))
+import pdb;pdb.set_trace()
+else:
+batch_size = int(batch_size // self.num_beams)
+next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking log, probabilities can be summed directly
+next_scores_exp = torch.exp(next_scores_raw)
+next_scores_raw1 = next_scores_exp.view(
+batch_size, self.num_beams * self.num_experts
+) # torch.Size([bz, num_beams*num_experts])
+next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
+# next_scores torch.Size([bz, num_beams])
+# next_tokens torch.Size([bz, num_beams])
+next_batch_beam = list()
+for batch_idx in range(batch_size):
+next_sent_beam = list()
+for rank, (expert_id, expert_score) in enumerate(
+zip(next_experts[batch_idx], next_scores[batch_idx])
+):
+expert_id = expert_id.item()
+beam_id = expert_id // self.num_experts
+ex_id = expert_id % self.num_experts
+effective_beam_id = batch_idx*self.num_beams + beam_id
+next_sent_beam.append((expert_score, ex_id, effective_beam_id))
+next_batch_beam.extend(next_sent_beam)
+# import pdb;pdb.set_trace()
+beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
+beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
+pre_route = expert_route[beam_idx,:]
+expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
+print("next_scores_raw1:\n",next_scores_raw1)
 return beam_scores, expert_route, beam_idx
-def forward_expert_ffn(self, x, expert_select, beam_scores):
+def forward_expert_ffn(self, x, expert_select, current_scores):
 """
 x_repeat : [bz*num_beams, 32,768]
 expert_select : [bz*num_beams]
+current_scores : [bz*num_beams, num_experts] / [bz, num_experts]
 """
-# add_1212 l2_normalization
-# normalized_tensor = torch.nn.functional.normalize(beam_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
+# add_1228 l2_normalization
+# normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
 # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)
-import pdb;pdb.set_trace()
 outputs = list()
-for i in range(x.shape[0]):
-output_x = self.experts[expert_select[i]].forward(x[i])
-outputs.append(output_x.unsqueeze(0))
-candidate_output = torch.cat(outputs)
-# candidate_output = candidate_output * tmp_prob
-return candidate_output # torch.Size([bz*num_beams, 32, 768])
+for i in range(self.num_experts):
+output_x = self.experts[i].forward(x)
+outputs.append(output_x.unsqueeze(1))
+candidate_output = torch.cat(outputs, dim=1)
+expert_select_matrix = F.one_hot(expert_select, self.num_experts)
+if self.weight_type == 'ffn_prob':
+tmp_prob = current_scores * expert_select_matrix
+candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
+else:
+candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
+import pdb;pdb.set_trace()
+output = torch.sum(candidate_output, dim=1)
+return output # torch.Size([bz*num_beams, 32, 768])
 def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
-import pdb;pdb.set_trace()
-current_scores = self.forward_gate(x) # [bz*num_beams, 5]
+current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams]
+importance_loss = self._importance_auxiliary_loss(current_scores)
 if use_log:
 current_scores_log = torch.log(current_scores) # after taking log, scores can be summed directly
@@ -215,42 +342,45 @@ class RouteMoELayer(nn.Module):
 current_scores_log = current_scores
 batch_size, num_tokens = x.shape[0], x.shape[1]
-beam_scores, expert_route, _ = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
+beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
 current_expert_select = expert_route[:,-1]
-import pdb;pdb.set_trace()
 if self.layer_judge=='first': # expand first dim to batch_size * num_beams
 replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
 x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
+current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
+current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
-candidate_output = self.forward_expert_ffn(x, current_expert_select, beam_scores) # [bz*num_beams, 32,768]
+input_x = x[beam_idx]
+candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768]
-return candidate_output, beam_scores, expert_route
+import pdb;pdb.set_trace()
+return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
 def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
+# if self.layer_judge=='first': # expand first dim to batch_size * num_beams
+# batch_size, num_tokens = x.shape[0], x.shape[1]
+# replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
+# x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
 attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
 x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
 def forward_expert(input_x, expert_idx):
 output_x = self.experts[expert_idx].forward(input_x)
 return output_x
-import pdb; pdb.set_trace()
 outputs = list()
 logits_gate_lst = list()
 for expert_idx in range(self.num_experts):
 output_x = forward_expert(x_masked, expert_idx)
+# output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
+output_x_aver = torch.mean(output_x, dim=1)
+# gate_score = self.gates[expert_idx](output_x_aver)
+gate_score = self.gate(output_x_aver)
+logits_gate_lst.append(gate_score)
 outputs.append(output_x.unsqueeze(0))
-output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
-gate_acore = self.gates[expert_idx](output_x_aver)
-logits_gate_lst.append(gate_acore)
-candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
+candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
 logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
 current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
@@ -259,25 +389,39 @@ class RouteMoELayer(nn.Module):
 else:
 current_scores_log = current_scores
-import pdb;pdb.set_trace()
+# importance loss
+importance_loss = self._importance_auxiliary_loss(current_scores)
+# import pdb; pdb.set_trace()
-batch_size = x.shape[0] # bz*num_beam
+batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
 beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
 # beam_scores torch.Size([bz*num_beam])
 # expert_route torch.Size([bz*num_beam, layer_n])
 current_select_expert = expert_route[:,-1]
-# current_select_expert torch.Size([bz*num_beam, 1])
-output = list()
-for i in range(beam_idx.shape[0]):
-b_idx = beam_idx[i]
-ex_idx = current_select_expert[i]
-ex_out = candidate_output[ex_idx, b_idx, :,:]
-output.append(ex_out.unsqueeze(0))
-final_output = torch.concat(output, dim=0)
-return final_output, beam_scores, expert_route, beam_idx
+# import pdb; pdb.set_trace()
+if self.layer_judge == 'first':
+replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
+candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
+current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
+current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
+candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768])
+expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
+if self.weight_type == 'ffn_prob':
+tmp_prob = current_scores[beam_idx] * expert_select_matrix
+output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
+else:
+output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
+final_output = torch.sum(output, dim=1)
+import pdb; pdb.set_trace()
+print("current_scores:\n",current_scores)
+return final_output, beam_scores, expert_route, beam_idx, importance_loss
 def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
 """
@@ -286,13 +430,12 @@ class RouteMoELayer(nn.Module):
 """
 if self.route_method == 'pre-route':
-candidate_output, beam_scores, expert_route, _ = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
+candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
 elif self.route_method == "post-route":
-candidate_output, beam_scores, expert_route, beam_idx = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
+candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
-return candidate_output, beam_scores, expert_route, beam_idx
+return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
 if __name__ == '__main__':
 import sys
@@ -314,8 +457,8 @@ if __name__ == '__main__':
 config.add_cross_attention = True
 config.cross_attention_freq = cross_attention_freq
 config.query_length = num_query_token
-config.moebert_expert_num = 3
-config.moebert_num_beams = 3
+config.moebert_expert_num = 2
+config.moebert_num_beams = 2
 config.moebert_route_method = 'gate-sentence'
 config.moe_topk = 2
 config.use_balance_loss = False
@@ -332,40 +475,46 @@ if __name__ == '__main__':
 for layer_num in [6, 8, 10]:
 layer_judge = moe_layer_judge(layer_num)
 ffn = FeedForward(config)
-gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
 # experts = RouteMoELayer(
 # hidden_size=768,
 # expert=ffn,
-# gate = gate,
 # num_experts=config.moebert_expert_num,
 # num_beams=config.moebert_num_beams,
 # layer_judge = layer_judge,
-# route_method = "pre-route"
+# route_method = "pre-route",
+# weight_type="no_ffn_prob"
 # )
 # layer_output = experts(x, None, beam_scores, expert_route)
-# hidden_states1, beam_scores, expert_route,_ = layer_output
+# hidden_states1, beam_scores, expert_route, beam_idx, importance_loss = layer_output
 # print(beam_scores)
 # print(expert_route)
+# print(beam_idx)
+# print(importance_loss)
+# x = hidden_states1
 gate1 = nn.Linear(768, 1, bias=False).float()
 experts_post = RouteMoELayer(
 hidden_size=768,
 expert=ffn,
-gate = gate1,
 num_experts=config.moebert_expert_num,
 num_beams=config.moebert_num_beams,
 layer_judge = layer_judge,
-route_method = "post-route"
+route_method = "post-route",
+weight_type="ffn_prob"
 )
 layer_output = experts_post(x1, None, beam_scores1, expert_route1, False)
-hidden_states2, beam_scores1, expert_route1, beam_idx = layer_output
+hidden_states2, beam_scores1, expert_route1, beam_idx, importance_loss = layer_output
 print(beam_scores1)
 print(expert_route1)
 print(beam_idx)
+print(importance_loss)
+x1 = hidden_states2
+# gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
 # experts_moe = MoELayer(
 # hidden_size=config.hidden_size,
 # expert=ffn,
@@ -382,11 +531,62 @@ if __name__ == '__main__':
 # print(select_prob_gate)
 # print(gate_load)
-# x = hidden_states1
-x1 = hidden_states2
 # x2 = hidden_states3
 print("------------------------------------")
+import pdb; pdb.set_trace()
+def forward_post_route_backup(self, x, beam_scores, expert_route, use_log=True):
+attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
+x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
+def forward_expert(input_x, expert_idx):
+output_x = self.experts[expert_idx].forward(input_x)
+return output_x
+outputs = list()
+logits_gate_lst = list()
+for expert_idx in range(self.num_experts):
+output_x = forward_expert(x_masked, expert_idx)
+outputs.append(output_x.unsqueeze(0))
+# output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
+# gate_score = self.gates[expert_idx](output_x_aver)
+output_x_aver = torch.mean(output_x, dim=1)
+gate_score = self.gate(output_x_aver)
+logits_gate_lst.append(gate_score)
+candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
+logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
+current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
+if use_log:
+current_scores_log = torch.log(current_scores) # after taking log, scores can be summed directly
+else:
+current_scores_log = current_scores
+# importance loss
+importance_loss = self._importance_auxiliary_loss(current_scores)
+batch_size = x.shape[0] # bz*num_beam
+beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
+# beam_scores torch.Size([bz*num_beam])
+# expert_route torch.Size([bz*num_beam, layer_n])
+current_select_expert = expert_route[:,-1]
+# current_select_expert torch.Size([bz*num_beam, 1])
+output = list()
+for i in range(beam_idx.shape[0]):
+b_idx = beam_idx[i]
+ex_idx = current_select_expert[i]
+ex_out = candidate_output[ex_idx, b_idx, :,:]
+if self.weight_type == 'ffn_prob':
+prob = current_scores[b_idx, ex_idx]
+ex_out = ex_out*prob
+output.append(ex_out.unsqueeze(0))
+final_output = torch.concat(output, dim=0)
+# import pdb;pdb.set_trace()
+return final_output, beam_scores, expert_route, beam_idx, importance_loss
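Note: the _importance_auxiliary_loss added in this file follows the V-MoE importance regularizer noted in its comment: sum the routing probabilities per expert over the batch and penalize the squared coefficient of variation of those totals. A tiny numeric check of that formula (a sketch with assumed shapes, not the repo's code):

import torch
import torch.nn.functional as F

def importance_auxiliary_loss(prob_gate):
    # prob_gate: [batch, num_experts] softmax routing probabilities.
    importance_per_expert = prob_gate.sum(dim=0)         # total probability mass per expert
    cv = importance_per_expert.std() / importance_per_expert.mean()
    return cv ** 2                                       # squared coefficient of variation

balanced = F.softmax(torch.zeros(8, 4), dim=-1)          # uniform routing
skewed = torch.zeros(8, 4)
skewed[:, 0] = 1.0                                       # everything routed to expert 0
print(importance_auxiliary_loss(balanced).item())        # 0.0
print(importance_auxiliary_loss(skewed).item())          # 4.0, heavily penalized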


@@ -1,155 +0,0 @@ (file deleted)
import torch
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
def forward_expert(input_x, expert_idx):
input_x += torch.randn(32,768)
return input_x
# output_x = self.experts[expert_idx].forward(input_x)
# return output_x
def forward_ffn(x_repeat, expert_select):
"""
x_repeat : [bz*num_beams, 32,768]
expert_select : [bz*num_beams]
"""
outputs = list()
num_beams_bz = x_repeat.shape[0]
for i in range(num_beams_bz):
output_x = forward_expert(x_repeat[i], expert_select[i]) # (32,768)
outputs.append(output_x.unsqueeze(0))
candidate_output = torch.cat(outputs)
return candidate_output # torch.Size([bz*num_beams, 32, 768])
def forward_gate(x, num_expert):
"""
x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts])
"""
# attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
# x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768])
# x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
# logits_gate = gate(x_average) # torch.Size([bz, num_experts])
logits_gate = torch.randn(x.shape[0], num_expert)
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
return prob_gate
def beam_search(layer, current_scores, beam_scores, expert_route, num_beams):
if layer == 0 and beam_scores==None and expert_route==None:
topk_values, gate = torch.topk(current_scores, num_beams, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, num_beams])
beam_scores = topk_values.view(num_beams*batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(num_beams*batch_size).unsqueeze(1) # torch.Size([bz * num_beams, 1])
else:
next_scores_raw = current_scores + beam_scores.unsqueeze(1) # torch.Size([4*3, 5]) # after taking the log, probabilities can be summed directly
next_scores_raw1 = next_scores_raw.view(
batch_size, num_beams * num_expert
) # torch.Size([4, 3*5])
next_scores, next_experts = torch.topk(next_scores_raw1, num_beams, dim=1, largest=True, sorted=True)
# next_scores torch.Size([4, 3*num_beams])
# next_tokens torch.Size([4, 3*num_beams])
next_batch_beam = list()
for batch_idx in range(batch_size):
next_sent_beam = list()
print(batch_idx)
for rank, (expert_id, expert_score) in enumerate(
zip(next_experts[batch_idx], next_scores[batch_idx])
):
expert_id = expert_id.item()
beam_id = expert_id // num_expert
ex_id = expert_id % num_expert
effective_beam_id = batch_idx*num_beams + beam_id
# print(expert_id, beam_id, ex_id, effective_beam_id, expert_score)
next_sent_beam.append((expert_score, ex_id, effective_beam_id))
next_batch_beam.extend(next_sent_beam)
# print()
import pdb;pdb.set_trace()
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
pre_route = expert_route[beam_idx,:]
expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
return beam_scores, expert_route
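# Worked example (added for clarity): with num_expert=5 and num_beams=2, each row of
# next_scores_raw1 has width num_beams*num_expert = 10. A flat index of 7 decomposes as
# beam_id = 7 // 5 = 1 and ex_id = 7 % 5 = 2, i.e. beam 1 of that sample extended by expert 2,
# and effective_beam_id = batch_idx*num_beams + beam_id picks the matching row of x and beam_scores.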
if __name__ == '__main__':
batch_size = 3
num_beams = 2
num_expert = 5
x = torch.randn(batch_size, 32, 768)
beam_scores, expert_route = None, None
for layer in range(0,3):
# import pdb;pdb.set_trace()
current_scores = forward_gate(x, num_expert)
import pdb;pdb.set_trace()
beam_scores, expert_route = beam_search(layer, current_scores, beam_scores, expert_route, num_beams)
current_expert_select = expert_route[:,-1]
if layer == 0:
replicated_tensor = x.unsqueeze(1).expand(batch_size, num_beams, 32, 768)
x = replicated_tensor.contiguous().view(-1, 32, 768) # [12,32,768] [bz*num_beams, 32,768]
else:
x = candidate_output
candidate_output = forward_ffn(x, current_expert_select) # torch.Size([4*3, 5])
x = candidate_output
scores = beam_scores.view(batch_size, num_beams)
topk_values, gate = torch.topk(scores, 1, dim=1)
# gate [batch_size, 1]
# topk_values [batch_size, 1]
selects = [ (bz_idx * num_beams + gate[bz_idx].item()) for bz_idx in range(batch_size)]
final_scores = beam_scores[selects]
final_expert_route = expert_route[selects]
final_output = candidate_output[selects]
# def forward_ffn_post(x_repeat, expert_select):
# """
# x_repeat : [bz*num_beams, 32,768]
# expert_select : [bz*num_beams]
# prob_gate : torch.Size([bz*num_beams, num_experts])
# """
# outputs = list()
# logits_gate_lst = list()
# # attention_mask = torch.ones([batch_size, 32])
# for i in range(num_beams*batch_size):
# output_x = forward_expert(x_repeat[i], expert_select[i]) # (32,768)
# outputs.append(output_x.unsqueeze(0))
# # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
# # gate_acore = self.gates[expert_idx](output_x_aver)
# # gate_score = self.gate(output_x_aver)
# num_expert = 5
# gate_score = torch.randn(1,num_expert)
# logits_gate_lst.append(gate_score)
# candidate_output = torch.cat(outputs) # torch.Size([bz*num_beams, 32, 768])
# logits_gate = torch.cat(logits_gate_lst,dim=0)# torch.Size([bz*num_beams, num_expert])
# prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
# return prob_gate, candidate_output

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class MoELayer(nn.Module): class MoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'): def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob'):
# remove hash list # remove hash list
nn.Module.__init__(self) nn.Module.__init__(self)
self.num_experts = num_experts self.num_experts = num_experts
@ -81,54 +81,6 @@ class MoELayer(nn.Module):
return x, balance_loss, gate_load return x, balance_loss, gate_load
def _forward_gate_sentence_top1_raw(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which will be added to the attention scores directly;
the values of extended_attention_mask are -0.0 or -10000,
so it should be converted to a 1/0 mask before being processed by the experts
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
gate = torch.argmax(prob_gate, dim=-1) # torch.Size([bz])
order = gate.argsort(0)
num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_sentences.clone()
x = x[order] # reorder according to expert number
x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_sentences.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x.unsqueeze(-1)
return input_x
result = []
for i in range(self.num_experts):
if x[i].size(0) > 0:
result.append(forward_expert(x[i], prob_gate[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
return result, balance_loss, gate_load
def _forward_gate_sentence_post(self, x, attention_mask): def _forward_gate_sentence_post(self, x, attention_mask):
""" """
x: query_attention_output; torch.Size([bz, 32, 768]) x: query_attention_output; torch.Size([bz, 32, 768])
@ -174,13 +126,17 @@ class MoELayer(nn.Module):
# importance loss # importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate) importance_loss = self._importance_auxiliary_loss(prob_gate)
# output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
# output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
# logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])
prob_gate_topk = torch.zeros_like(prob_gate) prob_gate_topk = torch.zeros_like(prob_gate)
prob_gate_topk.scatter_(1, gate, topk_values) prob_gate_topk.scatter_(1, gate, topk_values)
prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) # torch.Size([bz, num_expert])
if self.weight_type == 'average':
# torch.Size([bz, num_expert]); prob_gate_norm is 0 for unselected experts
prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True)
elif self.weight_type == 'raw_prob':
prob_gate_normalized = prob_gate_topk
elif self.weight_type == 'softmax_norm':
prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert])
candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768]) candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768]) results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768]) moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
@ -188,6 +144,46 @@ class MoELayer(nn.Module):
return moe_result, (balance_loss+importance_loss), prob_gate_normalized return moe_result, (balance_loss+importance_loss), prob_gate_normalized
def router(self, x, attention_mask):
# Prepare input x
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = torch.mean(x_masked, dim=1) # torch.Size([bz, 768])
# Forward Gate
# logits_gate: [bz, num_experts]
logits_gate = self.gate(x_average)
# Probabilities for each sample of what expert it should be sent to.
# prob_gate: [bz, num_experts]
prob_gate = F.softmax(logits_gate, dim=-1)
# Get Top-K experts for each sample
# gate: [bz, topk]
# select_prob_gate: [bz, topk]
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)
# Reshape Prob_gate & Gate
# expert_mask: [batch_size, topk, num_experts]
# expert_gate: [batch_size, topk, num_experts]
# combine_tensor: [batch_size, num_experts]
expert_mask = F.one_hot(gate, self.num_experts)
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = torch.sum(expert_gate, dim=1)
# Calculate Balancing Loss
if self.use_balance_loss:
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# Calculate Importance Loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# import pdb; pdb.set_trace()
return expert_mask, combine_tensor, balance_loss, importance_loss
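# Worked example (added, hypothetical numbers): with num_experts=4, topk=2 and one sample whose
# gate row is [2, 0] with select_prob_gate [0.5, 0.3], expert_mask one-hots each selection,
# expert_gate scales each one-hot by its probability, and summing over the topk dim gives the
# combine_tensor row [0.3, 0.0, 0.5, 0.0]: selected experts keep their raw probs, the rest are zero.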
def _forward_gate_sentence(self, x, attention_mask): def _forward_gate_sentence(self, x, attention_mask):
""" """
@ -200,81 +196,37 @@ class MoELayer(nn.Module):
the values of extended_attention_mask are -0.0 or -10000 the values of extended_attention_mask are -0.0 or -10000
it should be adjust to 1/0 version to be processed by experts it should be adjust to 1/0 version to be processed by experts
""" """
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) # Forward Router
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) expert_mask, combine_tensor, balance_loss, importance_loss = self.router(x, attention_mask)
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts]) # Forward Expert FFN
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) result = []
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) for expert_idx in range(self.num_experts):
output_x = self.experts[expert_idx].forward(x)
result.append(output_x.unsqueeze(0))
expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
# 这里用l2 norm 去加权 # multiply outputs of experts by the routing probability
if self.weight_type == 'l2_norm': if self.weight_type == 'raw_prob':
normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 Normalization torch.Size([bz, topk]) expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
elif self.weight_type == 'average': elif self.weight_type == 'no_prob':
normalized_tensor = select_prob_gate / select_prob_gate.sum(dim=1, keepdim=True) combine_index = torch.sum(expert_mask, dim=1)
expert_outputs_combined = expert_output * combine_index.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states])
gate_load = num_sentences.clone()
# load balancing loss # import pdb; pdb.set_trace()
if self.use_balance_loss:
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# importance loss return outputs, (balance_loss+importance_loss), combine_tensor
importance_loss = self._importance_auxiliary_loss(prob_gate)
# forward experts
def forward_expert(input_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
return input_x
result_lst = list()
for i in range(self.topk):
# top1、top2... 分别为一组进行gate分组之后过expert然后乘以概率后相加
tmp_gate = gate[:,i]
tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1)
order = tmp_gate.argsort(0)
num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0)
x1 = x[order] # reorder according to expert number
x1 = x1.split(num_sentences_t.tolist(), dim=0) # a list of length self.num_experts
result = []
for i in range(self.num_experts):
if x1[i].size(0) > 0:
result.append(forward_expert(x1[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
# result_lst.append(result * tmp_prob) # result * prob
result_lst.append(result) # result * prob # add_1212
moe_result = sum(result_lst)
# import pdb;pdb.set_trace()
return moe_result, (balance_loss+importance_loss), gate
def _forward_sentence_single_expert(self, x, attention_mask):
x_masked = x * attention_mask.unsqueeze(-1)
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
logits_gate = self.gate(x_average)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
x = self.experts[gate.cpu().item()].forward(x)
return x, 0.0, gate_load
def forward(self, x, attention_mask): def forward(self, x, attention_mask):
if self.route_method == "gate-token": if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x) x, balance_loss, gate_load = self._forward_gate_token(x)
elif self.route_method == "gate-sentence": elif self.route_method == "gate-sentence":
if x.size(0) == 1: x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
else:
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
elif self.route_method == "gate-sentence-post": elif self.route_method == "gate-sentence-post":
x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask) x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
else: else:
raise KeyError("Routing method not supported.") raise KeyError("Routing method not supported.")
# import pdb; pdb.set_trace()
return x, balance_loss, gate_load return x, balance_loss, gate_load

View File

@ -0,0 +1,330 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
self.route_method = route_method
self.topk = topk
self.use_balance_loss = use_balance_loss
self.weight_type = weight_type
if route_method in ["gate-token", "gate-sentence"]:
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
elif route_method in ["gate-sentence-post"]:
gate = nn.Linear(hidden_size, 1, bias=False).float()
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
self.gate = gate
else:
raise KeyError("Routing method not supported.")
def _balancing_loss(self, prob_gate, num_tokens):
# From MOEBERT
# compute the load balancing loss
# prob_gate is [bz, num_expert]: the probability of each sample being assigned to each expert
# equivalent to _gshard_auxiliary_loss in VMOE
P = prob_gate.mean(0) # torch.Size([num_expert]): average probability mass assigned to each expert
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True) # fraction of samples assigned to each expert
balance_loss = self.num_experts * torch.sum(P * f)
return balance_loss
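# Worked example (added, hypothetical numbers): with 2 experts and 4 samples all routed to expert 0
# with prob_gate rows of roughly [0.9, 0.1], P = [0.9, 0.1] and f = [1.0, 0.0], so the loss is
# 2 * (0.9*1.0 + 0.1*0.0) = 1.8; with a balanced assignment (P = f = [0.5, 0.5]) it drops to
# 2 * 0.5 = 1.0, so the loss pushes the router toward a uniform expert load.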
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
def _forward_gate_token(self, x):
bsz, seq_len, dim = x.size()
x = x.view(-1, dim)
logits_gate = self.gate(x)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
order = gate.argsort(0)
num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_tokens.clone()
x = x[order] # reorder according to expert number
x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x
return input_x
x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
x = torch.vstack(x)
x = x[order.argsort(0)] # restore original order
x = x.view(bsz, seq_len, dim)
return x, balance_loss, gate_load
def _forward_gate_sentence_top1_raw(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which will be added to the attention scores directly;
the values of extended_attention_mask are -0.0 or -10000,
so it should be converted to a 1/0 mask before being processed by the experts
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
gate = torch.argmax(prob_gate, dim=-1) # torch.Size([bz])
order = gate.argsort(0)
num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_sentences.clone()
x = x[order] # reorder according to expert number
x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_sentences.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x.unsqueeze(-1)
return input_x
result = []
for i in range(self.num_experts):
if x[i].size(0) > 0:
result.append(forward_expert(x[i], prob_gate[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
return result, balance_loss, gate_load
def _forward_gate_sentence_post(self, x, attention_mask):
"""
x: query_attention_output; torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
bz = 4
x = torch.randn(bz,32,768)
attention_mask = torch.ones([bz, 32])
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
# input_x += torch.randn(4,32,768)
# return input_x
output_x = self.experts[expert_idx].forward(input_x)
return output_x
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx)
outputs.append(output_x.unsqueeze(0))
output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
# gate_score = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
topk_values, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
gate_load = num_sentences.clone()
# load balancing loss
if self.use_balance_loss:
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
# output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
# logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])
prob_gate_topk = torch.zeros_like(prob_gate)
prob_gate_topk.scatter_(1, gate, topk_values)
if self.weight_type == 'average':
# torch.Size([bz, num_expert]); prob_gate_norm is 0 for unselected experts
prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True)
elif self.weight_type == 'raw_prob':
prob_gate_normalized = prob_gate_topk
elif self.weight_type == 'softmax_norm':
prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert])
candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
# import pdb;pdb.set_trace()
return moe_result, (balance_loss+importance_loss), prob_gate_normalized
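# Worked example (added, hypothetical numbers): with num_experts=3, topk=1 and expert 2 selected
# with prob 0.6, prob_gate_topk = [0, 0, 0.6]; 'average' renormalizes it to [0, 0, 1.0],
# 'raw_prob' keeps [0, 0, 0.6], and 'softmax_norm' softmaxes the scattered vector to roughly
# [0.26, 0.26, 0.48], which also assigns weight to the unselected experts via exp(0) = 1.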
# def _forward_gate_sentence(self, x, attention_mask):
# attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
# x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
# x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
# logits_gate = self.gate(x_average)
# prob_gate = F.softmax(logits_gate, dim=-1)
# gate = torch.argmax(prob_gate, dim=-1)
# order = gate.argsort(0)
# num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
# gate_load = num_sentences.clone()
# x = x[order] # reorder according to expert number
# x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts
# # compute the load balancing loss
# P = prob_gate.mean(0)
# temp = num_sentences.float()
# f = temp / temp.sum(0, keepdim=True)
# balance_loss = self.num_experts * torch.sum(P * f)
# prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
# prob_gate = prob_gate[order]
# prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)
# def forward_expert(input_x, prob_x, expert_idx):
# input_x = self.experts[expert_idx].forward(input_x)
# input_x = input_x * prob_x.unsqueeze(-1)
# return input_x
# result = []
# for i in range(self.num_experts):
# if x[i].size(0) > 0:
# result.append(forward_expert(x[i], prob_gate[i], i))
# result = torch.vstack(result)
# result = result[order.argsort(0)] # restore original order
# return result, balance_loss, gate_load
def _forward_gate_sentence(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which will be added to the attention scores directly;
the values of extended_attention_mask are -0.0 or -10000,
so it should be converted to a 1/0 mask before being processed by the experts
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
# weight the selected experts (l2-norm option below)
if self.weight_type == 'l2_norm':
# actually neither dim=0 nor dim=1 is right
normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=1) # L2 Normalization torch.Size([bz, topk])
elif self.weight_type == 'l2_norm_0':
normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
elif self.weight_type == 'average':
normalized_tensor = select_prob_gate / select_prob_gate.sum(dim=1, keepdim=True)
elif self.weight_type == 'raw_prob':
normalized_tensor = select_prob_gate
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
gate_load = num_sentences.clone()
# load balancing loss
if self.use_balance_loss:
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# forward experts
def forward_expert(input_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
return input_x
result_lst = list()
for i in range(self.topk):
# group samples by their top-1, top-2, ... selections; run each group through its expert, then weight by probability and sum
tmp_gate = gate[:,i]
tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1)
order = tmp_gate.argsort(0)
num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0)
x1 = x[order] # reorder according to expert number
x1 = x1.split(num_sentences_t.tolist(), dim=0) # a list of length self.num_experts
result = []
for i in range(self.num_experts):
if x1[i].size(0) > 0:
result.append(forward_expert(x1[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
result_lst.append(result * tmp_prob) # result * prob
# result_lst.append(result) # result * prob # add_1212
moe_result = sum(result_lst)
return moe_result, (balance_loss+importance_loss), gate
def _forward_sentence_single_expert(self, x, attention_mask):
x_masked = x * attention_mask.unsqueeze(-1)
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
logits_gate = self.gate(x_average)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
x = self.experts[gate.cpu().item()].forward(x)
return x, 0.0, gate_load
def forward(self, x, attention_mask):
if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x)
elif self.route_method == "gate-sentence":
if x.size(0) == 1:
x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
else:
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
elif self.route_method == "gate-sentence-post":
x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
else:
raise KeyError("Routing method not supported.")
# import pdb; pdb.set_trace()
return x, balance_loss, gate_load
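# Hedged usage sketch (added, not part of the original file): any module whose forward maps
# [bz, num_tokens, hidden] -> [bz, num_tokens, hidden] can serve as the expert template, since
# MoELayer deep-copies it num_experts times. For example:
#   toy_expert = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
#   moe = MoELayer(hidden_size=768, expert=toy_expert, num_experts=3,
#                  route_method="gate-sentence", topk=2, use_balance_loss=True, weight_type="raw_prob")
#   out, aux_loss, gate_load = moe(torch.randn(4, 32, 768), torch.ones(4, 32))
#   # out: [4, 32, 768]; aux_loss: balance + importance loss; gate_load: the top-k expert indices per sample here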

View File

@ -92,7 +92,6 @@ class PrePromptMoE(PromptMoEBase):
self.topk = topk self.topk = topk
if route_method in ["gate-token", "gate-single-token", "gate-sentence"]: if route_method in ["gate-token", "gate-single-token", "gate-sentence"]:
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
print(self.gate)
else: else:
raise KeyError("Routing method not supported.") raise KeyError("Routing method not supported.")

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class RouteMoELayer(nn.Module): class RouteMoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route"): def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
# remove hash list # remove hash list
nn.Module.__init__(self) nn.Module.__init__(self)
self.num_experts = num_experts self.num_experts = num_experts
@ -13,6 +13,7 @@ class RouteMoELayer(nn.Module):
self.num_beams = num_beams self.num_beams = num_beams
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.layer_judge = layer_judge self.layer_judge = layer_judge
self.weight_type = weight_type
self.route_method = route_method self.route_method = route_method
if self.route_method == "pre-route": if self.route_method == "pre-route":
@ -22,6 +23,17 @@ class RouteMoELayer(nn.Module):
self.gate = gate self.gate = gate
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
def forward_gate(self, x): def forward_gate(self, x):
""" """
x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768]) x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
@ -29,7 +41,8 @@ class RouteMoELayer(nn.Module):
""" """
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768]) x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768]) # x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
x_average = torch.mean(x_masked, dim=1) # torch.Size([bz*num_beams, 768])
logits_gate = self.gate(x_average) # torch.Size([bz*num_beams, num_experts]) logits_gate = self.gate(x_average) # torch.Size([bz*num_beams, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts]) prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
return prob_gate return prob_gate
@ -42,7 +55,7 @@ class RouteMoELayer(nn.Module):
topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk])
beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams]) beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
beam_idx = None beam_idx = torch.tensor(range(self.num_beams * batch_size))
else: else:
if self.layer_judge=='first' and self.route_method == 'post-route': if self.layer_judge=='first' and self.route_method == 'post-route':
batch_size = batch_size batch_size = batch_size
@ -89,54 +102,63 @@ class RouteMoELayer(nn.Module):
return beam_scores, expert_route, beam_idx return beam_scores, expert_route, beam_idx
def forward_expert_ffn(self, x, expert_select, current_scores):
def forward_expert_ffn(self, x, expert_select, beam_scores):
""" """
x_repeat : [bz*num_beams, 32,768] x_repeat : [bz*num_beams, 32,768]
expert_select : [bz*num_beams] expert_select : [bz*num_beams]
current_scores : [bz*num_beams, num_experts] / [bz, num_experts]
""" """
# add_1212 l2_normalization # add_1228 l2_normalization
# normalized_tensor = torch.nn.functional.normalize(beam_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk]) # normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
# tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1) # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)
# import pdb;pdb.set_trace()
outputs = list() outputs = list()
for i in range(x.shape[0]): for i in range(self.num_experts):
output_x = self.experts[expert_select[i]].forward(x[i]) output_x = self.experts[i].forward(x)
outputs.append(output_x.unsqueeze(0)) outputs.append(output_x.unsqueeze(1))
candidate_output = torch.cat(outputs) candidate_output = torch.cat(outputs, dim=1)
expert_select_matrix = F.one_hot(expert_select, self.num_experts)
# candidate_output = candidate_output * tmp_prob if self.weight_type == 'ffn_prob':
return candidate_output # torch.Size([bz*num_beams, 32, 768]) tmp_prob = current_scores * expert_select_matrix
candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
output = torch.sum(candidate_output, dim=1)
# import pdb;pdb.set_trace()
return output # torch.Size([bz*num_beams, 32, 768])
def forward_pre_route(self, x, beam_scores, expert_route, use_log=True): def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
current_scores = self.forward_gate(x) # [bz*num_beams, 5] current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams]
importance_loss = self._importance_auxiliary_loss(current_scores)
if use_log: if use_log:
current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 current_scores_log = torch.log(current_scores) # 取log之后可以直接相加
else: else:
current_scores_log = current_scores current_scores_log = current_scores
# import pdb;pdb.set_trace()
batch_size, num_tokens = x.shape[0], x.shape[1] batch_size, num_tokens = x.shape[0], x.shape[1]
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
current_expert_select = expert_route[:,-1] current_expert_select = expert_route[:,-1]
if self.layer_judge=='first': # expand first dim to batch_size * num_beams if self.layer_judge=='first': # expand first dim to batch_size * num_beams
replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size) replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
candidate_output = self.forward_expert_ffn(x, current_expert_select, beam_scores) # [bz*num_beams, 32,768] input_x = x[beam_idx]
candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768]
return candidate_output, beam_scores, expert_route, beam_idx # import pdb;pdb.set_trace()
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
def forward_post_route(self, x, beam_scores, expert_route, use_log=True): def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx): def forward_expert(input_x, expert_idx):
output_x = self.experts[expert_idx].forward(input_x) output_x = self.experts[expert_idx].forward(input_x)
return output_x return output_x
@ -145,12 +167,14 @@ class RouteMoELayer(nn.Module):
logits_gate_lst = list() logits_gate_lst = list()
for expert_idx in range(self.num_experts): for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx) output_x = forward_expert(x_masked, expert_idx)
outputs.append(output_x.unsqueeze(0)) # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768]) output_x_aver = torch.mean(output_x, dim=1)
# gate_score = self.gates[expert_idx](output_x_aver) # gate_score = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver) gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score) logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768]) outputs.append(output_x.unsqueeze(0))
candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert]) logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts]) current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
@ -159,24 +183,33 @@ class RouteMoELayer(nn.Module):
else: else:
current_scores_log = current_scores current_scores_log = current_scores
batch_size = x.shape[0] # bz*num_beam # importance loss
importance_loss = self._importance_auxiliary_loss(current_scores)
batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
# beam_scores torch.Size([bz*num_beam]) # beam_scores torch.Size([bz*num_beam])
# expert_route torch.Size([bz*num_beam, layer_n]) # expert_route torch.Size([bz*num_beam, layer_n])
current_select_expert = expert_route[:,-1] current_select_expert = expert_route[:,-1]
# current_select_expert torch.Size([bz*num_beam, 1])
output = list() if self.layer_judge == 'first':
for i in range(beam_idx.shape[0]): replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
b_idx = beam_idx[i] candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
ex_idx = current_select_expert[i] current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
ex_out = candidate_output[ex_idx, b_idx, :,:] current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
output.append(ex_out.unsqueeze(0))
candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768])
final_output = torch.concat(output, dim=0) expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
if self.weight_type == 'ffn_prob':
return final_output, beam_scores, expert_route, beam_idx tmp_prob = current_scores[beam_idx] * expert_select_matrix
output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
final_output = torch.sum(output, dim=1)
return final_output, beam_scores, expert_route, beam_idx, importance_loss
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True): def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
""" """
if first_layer: x [bz, 32, 768] if first_layer: x [bz, 32, 768]
@ -184,11 +217,11 @@ class RouteMoELayer(nn.Module):
""" """
if self.route_method == 'pre-route': if self.route_method == 'pre-route':
candidate_output, beam_scores, expert_route, beam_idx = self.forward_pre_route(x, beam_scores, expert_route, use_log=True) candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
elif self.route_method == "post-route": elif self.route_method == "post-route":
candidate_output, beam_scores, expert_route, beam_idx = self.forward_post_route(x, beam_scores, expert_route, use_log=True) candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
return candidate_output, beam_scores, expert_route, beam_idx return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

View File

@ -0,0 +1,294 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob, topk(softmax)'):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
self.route_method = route_method
self.topk = topk
self.use_balance_loss = use_balance_loss
self.weight_type = weight_type
if route_method in ["gate-token", "gate-sentence"]:
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
elif route_method in ["gate-sentence-post"]:
gate = nn.Linear(hidden_size, 1, bias=False).float()
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
self.gate = gate
else:
raise KeyError("Routing method not supported.")
def _balancing_loss(self, prob_gate, num_tokens):
# From MOEBERT
# compute the load balancing loss
# prob_gate is [bz, num_expert]: the probability of each sample being assigned to each expert
# equivalent to _gshard_auxiliary_loss in VMOE
P = prob_gate.mean(0) # torch.Size([num_expert]): average probability mass assigned to each expert
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True) # fraction of samples assigned to each expert
balance_loss = self.num_experts * torch.sum(P * f)
return balance_loss
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
def _forward_gate_token(self, x):
bsz, seq_len, dim = x.size()
x = x.view(-1, dim)
logits_gate = self.gate(x)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
order = gate.argsort(0)
num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_tokens.clone()
x = x[order] # reorder according to expert number
x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x
return input_x
x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
x = torch.vstack(x)
x = x[order.argsort(0)] # restore original order
x = x.view(bsz, seq_len, dim)
return x, balance_loss, gate_load
def _forward_gate_sentence_post(self, x, attention_mask):
"""
x: query_attention_output; torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
bz = 4
x = torch.randn(bz,32,768)
attention_mask = torch.ones([bz, 32])
"""
# Prepare Input x
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
# FeedForward(x) & Forward Gate
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = self.experts[expert_idx].forward(x_masked)
outputs.append(output_x.unsqueeze(0))
output_x_aver = torch.mean(output_x, dim=1)
# gate_score = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert])
# Probabilities for each sample of what expert it should be sent to.
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
if 'softmax(topk)' in self.weight_type:
prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1)
select_prob_gate = F.softmax(prob_gate1, dim=-1)
else:
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
# Calculate Balancing Loss
if self.use_balance_loss:
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# Calculate Importance Loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# Reshape Prob_gate & Gate
# expert_mask: [batch_size, topk, num_experts]
# expert_gate: [batch_size, topk, num_experts]
# combine_tensor: [batch_size, num_experts]
expert_mask = F.one_hot(gate, self.num_experts)
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = torch.sum(expert_gate, dim=1)
# combine_tensor = torch.zeros_like(prob_gate)
# combine_tensor.scatter_(1, gate, select_prob_gate) # equivalent operation, but may not be differentiable
candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
results = candidate_output_ad * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([bz, num_expert, 32, 768])
outputs = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
import pdb;pdb.set_trace()
return outputs, (balance_loss+importance_loss), combine_tensor
def pre_router(self, x, attention_mask):
# Prepare input x
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = torch.mean(x_masked, dim=1) # torch.Size([bz, 768])
# Forward Gate
# logits_gate: [bz, num_experts]
logits_gate = self.gate(x_average)
# Probabilities for each sample of what expert it should be sent to.
# prob_gate: [bz, num_experts]
prob_gate = F.softmax(logits_gate, dim=-1)
if 'softmax(topk)' in self.weight_type:
prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1)
select_prob_gate = F.softmax(prob_gate1, dim=-1)
else:
# topk(softmax)
# Get Top-K experts for each sample
# gate: [bz, topk]
# select_prob_gate: [bz, topk]
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)
# Reshape Prob_gate & Gate
# expert_mask: [batch_size, topk, num_experts]
# expert_gate: [batch_size, topk, num_experts]
# combine_tensor: [batch_size, num_experts]
expert_mask = F.one_hot(gate, self.num_experts)
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = torch.sum(expert_gate, dim=1)
# Calculate Balancing Loss
if self.use_balance_loss:
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# Calculate Importance Loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
import pdb; pdb.set_trace()
return expert_mask, combine_tensor, balance_loss, importance_loss
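# Worked example (added, hypothetical numbers): with logits [2.0, 1.0, 0.0] and topk=2,
# 'softmax(topk)' first keeps the top-2 logits [2.0, 1.0] and softmaxes only those,
# giving weights of roughly [0.73, 0.27] that sum to 1; 'topk(softmax)' softmaxes all three logits
# to roughly [0.67, 0.24, 0.09] and then keeps the top-2 probs [0.67, 0.24], which sum to about 0.91.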
def _forward_gate_sentence(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which will be added to the attention scores directly;
the values of extended_attention_mask are -0.0 or -10000,
so it should be converted to a 1/0 mask before being processed by the experts
"""
# Forward Router
expert_mask, combine_tensor, balance_loss, importance_loss = self.pre_router(x, attention_mask)
# Forward Expert FFN
result = []
for expert_idx in range(self.num_experts):
output_x = self.experts[expert_idx].forward(x)
result.append(output_x.unsqueeze(0))
expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
# multiply outputs of experts by the routing probability
expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states])
import pdb; pdb.set_trace()
return outputs, (balance_loss+importance_loss), combine_tensor
def forward(self, x, attention_mask):
if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x)
elif self.route_method == "gate-sentence":
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
elif self.route_method == "gate-sentence-post":
x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
else:
raise KeyError("Routing method not supported.")
# import pdb; pdb.set_trace()
return x, balance_loss, gate_load
if __name__ == '__main__':
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
from minigpt4.models.QformerRouteMoE import BertConfig
from minigpt4.models.QformerRouteMoE import FeedForward
from minigpt4.models.moe.utils import (
moe_layer_judge,
)
vision_width = 1408
cross_attention_freq = 2
num_query_token = 32
# init_QformerMoE
config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
config.encoder_width = vision_width
# insert cross-attention layer every other block
config.add_cross_attention = True
config.cross_attention_freq = cross_attention_freq
config.query_length = num_query_token
config.moebert_expert_num = 3
config.moebert_num_beams = 2
config.moebert_route_method = 'gate-sentence-post'
config.moe_topk = 1
config.use_balance_loss = False
# config.moe_weight_type = 'raw_prob, softmax(topk)'
config.moe_weight_type = 'raw_prob, topk(softmax)'
batch_size = 4
x2 = torch.randn(batch_size, 32, 768)
beam_scores, expert_route = None, None
for layer_num in [6, 8, 10]:
layer_judge = moe_layer_judge(layer_num)
ffn = FeedForward(config)
gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
experts_moe = MoELayer(
hidden_size=config.hidden_size,
expert=ffn,
num_experts=config.moebert_expert_num,
route_method=config.moebert_route_method,
topk=config.moe_topk,
use_balance_loss=config.use_balance_loss,
weight_type=config.moe_weight_type,
)
attn_mask = torch.ones([batch_size, 32])
layer_output = experts_moe(x2, attn_mask)
hidden_states3, aux_loss, combine_tensor = layer_output
print(combine_tensor)
print(aux_loss)
x2 = hidden_states3
print("------------------------------------")
import pdb; pdb.set_trace()
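# Note (added for clarity): with the settings above (batch_size=4, 3 experts, gate-sentence-post,
# topk=1), hidden_states3 is expected to be [4, 32, 768], combine_tensor [4, 3] with one nonzero
# routing weight per row, and aux_loss reduces to the importance loss since use_balance_loss=False.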

View File

@ -19,15 +19,33 @@ def use_experts(layer_idx):
else: else:
return False return False
def use_experts_route(layer_idx):
# if layer_idx % 2 == 0:
# use moe_ffn after cross_attns
# if int(layer_idx) in [0,2,4,6,8,10]:
if int(layer_idx) in [6,7,8,9,10,11]:
return True
else:
return False
def moe_layer_judge(layer_idx): def moe_layer_judge(layer_idx):
if layer_idx == 6: if layer_idx == 6:
return 'first' return 'first'
elif layer_idx == 8: elif layer_idx in [7,8,9,10]:
return 'mid' return 'mid'
elif layer_idx == 10: elif layer_idx == 11:
return 'last' return 'last'
else: else:
return None return None
# if layer_idx == 0:
# return 'first'
# elif layer_idx in [2,4,6,8]:
# return 'mid'
# elif layer_idx == 10:
# return 'last'
# else:
# return None
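# Note (added for clarity): with the updated judge, the routed-MoE blocks now span the last six
# Q-Former layers: use_experts_route is True for layers 6-11, and moe_layer_judge maps layer 6 to
# 'first', layers 7-10 to 'mid', and layer 11 to 'last', replacing the previous 6/8/10 pattern.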
def process_ffn(model): def process_ffn(model):
if model.config.model_type == "bert": if model.config.model_type == "bert":

View File

@ -10,7 +10,6 @@ model:
load_finetuned: False load_finetuned: False
vit_model: eva_clip_g vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/20231201184/checkpoint_best.pth"
finetuned: "" finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
@ -38,7 +37,7 @@ model:
# moe # moe
use_moeqformer: True use_moeqformer: True
moebert_expert_num: 5 moebert_expert_num: 3
moebert_route_method: "gate-sentence-post" moebert_route_method: "gate-sentence-post"
moebert_load_balance: 0 moebert_load_balance: 0
moe_topk: 1 moe_topk: 1
@ -110,6 +109,7 @@ run:
max_epoch: 1 max_epoch: 1
num_workers: 4 num_workers: 4
warmup_steps: 600 warmup_steps: 600
iters_per_epoch: 1000
seed: 42 seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/" output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/"

View File

@ -10,7 +10,7 @@ model:
load_finetuned: True load_finetuned: True
vit_model: eva_clip_g vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_linear_gate_3ex_3beam_1loss_top3layer_log_textinqf_epo3_1216/20231216155/checkpoint_best.pth" finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/20240112212/checkpoint_best.pth"
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder # vit encoder
@ -38,10 +38,12 @@ model:
# moe # moe
use_moeqformer: True use_moeqformer: True
use_route_moe: True use_route_moe: True
moebert_expert_num: 3
moebert_num_beams: 3
moebert_route_method: "post-route" moebert_route_method: "post-route"
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209_eval_latest1/" moebert_load_balance: 0
moebert_expert_num: 2
moebert_num_beams: 2
moe_weight_type: 'ffn_prob'
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
datasets: datasets:
gqa: gqa:
@ -81,19 +83,20 @@ run:
task: instruction_tuning task: instruction_tuning
# optimizer # optimizer
lr_sched: "linear_warmup_cosine_lr" lr_sched: "linear_warmup_cosine_lr"
init_lr: 2e-5 init_lr: 5e-5
min_lr: 1e-6 min_lr: 1e-6
warmup_lr: 1e-6 warmup_lr: 1e-6
log_freq: 5 log_freq: 5
save_freq: 1500 save_freq: 1500
weight_decay: 0.05 weight_decay: 0.05
max_epoch: 5 max_epoch: 10
num_workers: 4 num_workers: 4
warmup_steps: 600 warmup_steps: 600
iters_per_epoch: 3000
seed: 42 seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/" output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
amp: True amp: True
resume_ckpt_path: null resume_ckpt_path: null

View File

@ -38,10 +38,12 @@ model:
# moe # moe
use_moeqformer: True use_moeqformer: True
use_route_moe: True use_route_moe: True
moebert_route_method: "post-route"
moebert_load_balance: 0
moebert_expert_num: 3 moebert_expert_num: 3
moebert_num_beams: 3 moebert_num_beams: 3
moebert_route_method: "post-route" moe_weight_type: 'ffn_prob'
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/" use_balance_loss: False
datasets: datasets:
gqa: # train: 943000, 12578, 12578) gqa: # train: 943000, 12578, 12578)
@ -97,19 +99,20 @@ run:
task: instruction_tuning task: instruction_tuning
# optimizer # optimizer
lr_sched: "linear_warmup_cosine_lr" lr_sched: "linear_warmup_cosine_lr"
init_lr: 2e-5 init_lr: 5e-5
min_lr: 1e-6 min_lr: 1e-6
warmup_lr: 1e-6 warmup_lr: 1e-6
log_freq: 5 log_freq: 5
save_freq: 1500 save_freq: 1500
weight_decay: 0.05 weight_decay: 0.05
max_epoch: 5 max_epoch: 8
num_workers: 4 num_workers: 4
warmup_steps: 600 warmup_steps: 600
iters_per_epoch: 5000
seed: 42 seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/" output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_1gate_3ex_3beam_1loss_5e5lr_top6layer_textinqf_epo8_0117/"
amp: True amp: True
resume_ckpt_path: null resume_ckpt_path: null

View File

@ -0,0 +1,129 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
# 0107test
model:
arch: blip2_vicuna_instruct
model_type: vicuna7b_pretrain
load_pretrained: True
load_finetuned: False
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
# Q-Former
num_query_token: 32
qformer_text_input: True
# vicuna
llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
prompt: ""
max_txt_len: 256
max_output_txt_len: 256
# freeze
freeze_vit: True
freeze_llm: True
freeze_qformer: False
freeze_t5_proj: False
# moe
use_moeqformer: True
use_route_moe: True
moebert_route_method: "post-route"
moebert_load_balance: 0
moebert_expert_num: 2
moebert_num_beams: 2
moe_weight_type: 'ffn_prob'
use_balance_loss: False
# gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
datasets:
# gqa: # train: 943000, 12578, 12578)
# type: balanced_sft_raw
# batch_size: 1
# vis_processor:
# train:
# name: "blip2_image_train"
# image_size: 224
# eval:
# name: "blip2_image_eval"
# image_size: 224
# text_processor:
# train:
# name: "blip_caption"
# eval:
# name: "blip_caption"
# sample_ratio: 10
ok_vqa: # train, valid (9009, 5046)
batch_size: 1
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 1
# coco_vqa: # 658104
# batch_size: 1
# vis_processor:
# train:
# name: "blip2_image_train"
# image_size: 224
# eval:
# name: "blip2_image_eval"
# image_size: 224
# text_processor:
# train:
# name: "blip_caption"
# eval:
# name: "blip_caption"
# sample_ratio: 9
run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 2e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 5
num_workers: 4
warmup_steps: 600
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
valid_splits: ["val"]
# test_splits: ["val"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
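
As an aside, the MoE-related fields of a config like the one above can be inspected with a few lines of Python; a minimal sketch assuming the YAML is saved locally (the file name below is a placeholder, not a path from this commit):

import yaml

with open("post_route_okvqa_eval.yaml") as f:  # placeholder path
    cfg = yaml.safe_load(f)

moe_keys = ["use_moeqformer", "use_route_moe", "moebert_route_method",
            "moebert_load_balance", "moebert_expert_num", "moebert_num_beams",
            "moe_weight_type", "use_balance_loss"]
print({k: cfg["model"].get(k) for k in moe_keys})
# -> {'use_moeqformer': True, 'use_route_moe': True, 'moebert_route_method': 'post-route', ...}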

View File

@ -10,7 +10,7 @@ model:
load_finetuned: True load_finetuned: True
vit_model: eva_clip_g vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_1048k_raw_QformerMoE_Route_Post_NoNorm_5ex_2beam_1loss_top3layer_textinqf_epo6_1215/20231216161/checkpoint_best.pth" finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/20240111145/checkpoint_best.pth"
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder # vit encoder
@ -39,8 +39,11 @@ model:
use_moeqformer: True use_moeqformer: True
use_route_moe: True use_route_moe: True
moebert_route_method: "post-route" moebert_route_method: "post-route"
moebert_expert_num: 5 moebert_load_balance: 0
moebert_expert_num: 2
moebert_num_beams: 2 moebert_num_beams: 2
moe_weight_type: 'ffn_prob'
use_balance_loss: False
datasets: datasets:
ok_vqa: # train, valid (9009, 5046) ok_vqa: # train, valid (9009, 5046)
@ -78,7 +81,7 @@ evaluation_datasets:
run: run:
task: instruction_tuning task: instruction_tuning
name: vqa_benchmark_evaluation name: vqa_benchmark_evaluation
save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/eval/benchmarks/mix_1048k_raw_QformerMoE_Route_Post_NoNorm_5ex_2beam_1loss_top3layer_textinqf_epo6_1215/" save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/benchmarks/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/"
seed: 42 seed: 42

View File

@ -0,0 +1,131 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
model:
arch: blip2_vicuna_instruct
model_type: vicuna7b_pretrain
load_pretrained: True
load_finetuned: False
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
# Q-Former
num_query_token: 32
qformer_text_input: True
# vicuna7b
llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
prompt: ""
max_txt_len: 256
max_output_txt_len: 256
# freeze
freeze_vit: True
freeze_llm: True
freeze_qformer: False
freeze_t5_proj: False
# moe
use_moeqformer: True
use_route_moe: True
moebert_route_method: "post-route"
moebert_load_balance: 0.05
moebert_expert_num: 3
moebert_num_beams: 3
moe_weight_type: 'ffn_prob'
use_balance_loss: False
datasets:
gqa: # train: 943000, 12578, 12578)
type: balanced_sft_raw
# batch_size: 16
batch_size: 32
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 50
ok_vqa: # train, valid (9009, 5046)
# batch_size: 16
batch_size: 32
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 8
coco_vqa: # 658104
# batch_size: 16
batch_size: 32
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 15
run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
# init_lr: 2e-5
init_lr: 5e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 8
num_workers: 4
warmup_steps: 600
iters_per_epoch: 5000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_3ex_3beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
valid_splits: ["val"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
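
For context, moebert_load_balance is the weight on the auxiliary balancing term (0.05 in this run, 0 in the evaluation configs above). A sketch of how such a weight typically enters the objective; this is an assumed formulation, not the exact code in this repo:

def total_loss(task_loss, balance_loss, load_balance_weight=0.05):
    # the "2loss" runs combine the task loss with a balancing term scaled by the config value
    return task_loss + load_balance_weight * balance_loss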

View File

@ -37,19 +37,19 @@ model:
# moe # moe
use_moeqformer: True use_moeqformer: True
moebert_expert_num: 5 moebert_expert_num: 3
moebert_route_method: "gate-sentence" moebert_route_method: "gate-sentence"
moebert_load_balance: 0 moebert_load_balance: 0
moe_topk: 1 moe_topk: 1
use_balance_loss: False use_balance_loss: False
moe_weight_type: 'l2_norm' moe_weight_type: 'raw_prob'
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/" # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/"
datasets: datasets:
gqa: # train: 94254 gqa: # train: 94254
type: balanced_sft_raw_part type: balanced_sft_raw_part
batch_size: 32 batch_size: 1
vis_processor: vis_processor:
train: train:
name: "blip2_image_train" name: "blip2_image_train"
@ -65,7 +65,7 @@ datasets:
sample_ratio: 50 sample_ratio: 50
ok_vqa: # train, valid (9009, 5046 ok_vqa: # train, valid (9009, 5046
batch_size: 32 batch_size: 1
vis_processor: vis_processor:
train: train:
name: "blip2_image_train" name: "blip2_image_train"
@ -80,22 +80,22 @@ datasets:
name: "blip_caption" name: "blip_caption"
sample_ratio: 8 sample_ratio: 8
coco_vqa: # 214352 vqa_val # coco_vqa: # 214352 vqa_val
type: vqa_v2_part # type: vqa_v2_part
batch_size: 32 # batch_size: 1
vis_processor: # vis_processor:
train: # train:
name: "blip2_image_train" # name: "blip2_image_train"
image_size: 224 # image_size: 224
eval: # eval:
name: "blip2_image_eval" # name: "blip2_image_eval"
image_size: 224 # image_size: 224
text_processor: # text_processor:
train: # train:
name: "blip_caption" # name: "blip_caption"
eval: # eval:
name: "blip_caption" # name: "blip_caption"
sample_ratio: 15 # sample_ratio: 15
run: run:
task: instruction_tuning task: instruction_tuning
@ -108,12 +108,13 @@ run:
save_freq: 1500 save_freq: 1500
weight_decay: 0.05 weight_decay: 0.05
max_epoch: 5 max_epoch: 1
num_workers: 4 num_workers: 4
warmup_steps: 600 warmup_steps: 600
iters_per_epoch: 1000
seed: 42 seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/" output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1220_test/"
amp: True amp: True
resume_ckpt_path: null resume_ckpt_path: null

View File

@ -0,0 +1,125 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
model:
arch: blip2_vicuna_instruct
model_type: vicuna7b_pretrain
load_pretrained: True
load_finetuned: False
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
# Q-Former
num_query_token: 32
qformer_text_input: True
# vicuna7b
llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
prompt: ""
max_txt_len: 256
max_output_txt_len: 256
# freeze
freeze_vit: True
freeze_llm: True
freeze_qformer: False
freeze_t5_proj: False
# moe
use_moeqformer: False
moebert_expert_num: 1
moebert_route_method: "gate-sentence"
moebert_load_balance: 0.05
moe_topk: 1
datasets:
gqa: # train: 943000, 12578, 12578)
type: balanced_sft_raw
batch_size: 16
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 50
ok_vqa: # train, valid (9009, 5046)
batch_size: 16
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 8
coco_vqa: # 658104
batch_size: 16
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 15
run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
# init_lr: 2e-5
init_lr: 5e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 8
num_workers: 4
warmup_steps: 600
iters_per_epoch: 5000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_1610k_raw_QformerMoE_train_qf_train_qt_1ex_top1_textinqf_epo8_lr5e5_seed42_0112/"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
valid_splits: ["val"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True
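
For reference, linear_warmup_cosine_lr ramps the learning rate linearly from warmup_lr to init_lr over warmup_steps and then decays it toward min_lr along a cosine curve. A rough sketch of that shape under the values above (an assumed step-wise formulation, not necessarily the scheduler's exact code):

import math

init_lr, min_lr, warmup_lr = 5e-5, 1e-6, 1e-6
warmup_steps = 600
total_steps = 8 * 5000  # max_epoch * iters_per_epoch

def lr_at(step):
    if step < warmup_steps:
        return warmup_lr + (init_lr - warmup_lr) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (init_lr - min_lr) * (1.0 + math.cos(math.pi * progress))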

View File

@ -110,6 +110,7 @@ class RunnerBase:
else: else:
p_wd.append(p) p_wd.append(p)
num_parameters += p.data.nelement() num_parameters += p.data.nelement()
# import pdb; pdb.set_trace() # 0107test
logging.info("number of trainable parameters: %d" % num_parameters) logging.info("number of trainable parameters: %d" % num_parameters)
optim_params = [ optim_params = [
{ {

View File

@ -238,13 +238,17 @@ class BaseTask:
with torch.cuda.amp.autocast(enabled=use_amp): with torch.cuda.amp.autocast(enabled=use_amp):
loss = self.train_step(model=model, samples=samples) loss = self.train_step(model=model, samples=samples)
# after_train_step() # after_train_step()
if use_amp: if use_amp:
# torch.autograd.set_detect_anomaly(True)
# detect anomalous values during backprop to help locate the offending code
# with torch.autograd.detect_anomaly():
scaler.scale(loss).backward() scaler.scale(loss).backward()
else: else:
loss.backward() loss.backward()
# import pdb; pdb.set_trace() # 0107test
# update gradients every accum_grad_iters iterations # update gradients every accum_grad_iters iterations
if (i + 1) % accum_grad_iters == 0: if (i + 1) % accum_grad_iters == 0:
if use_amp: if use_amp:
@ -252,6 +256,9 @@ class BaseTask:
scaler.update() scaler.update()
else: else:
optimizer.step() optimizer.step()
# import pdb; pdb.set_trace()# 0107test
optimizer.zero_grad() optimizer.zero_grad()
# if self.cfg.wandb_log: # if self.cfg.wandb_log:
# if self.cfg.run_cfg.wandb_log: # if self.cfg.run_cfg.wandb_log:
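
The hunk above only adds debugging hooks around the existing AMP / gradient-accumulation path. As a generic restatement of that pattern (a sketch, not the exact BaseTask loop), the scaled backward runs every iteration while the optimizer step and zero_grad fire once per accum_grad_iters:

import torch

def train_inner_loop_sketch(model, data_loader, optimizer, train_step,
                            use_amp=True, accum_grad_iters=1):
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    for i, samples in enumerate(data_loader):
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = train_step(model=model, samples=samples)
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        # update weights only every accum_grad_iters iterations
        if (i + 1) % accum_grad_iters == 0:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()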

View File

@ -44,4 +44,6 @@ wheel
visualizer visualizer
tensorboard tensorboard
kmeans_pytorch kmeans_pytorch
visual_genome visual_genome
gpustat
torchviz

36
setup.py Normal file
View File

@ -0,0 +1,36 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from setuptools import setup, find_namespace_packages
import platform
DEPENDENCY_LINKS = []
if platform.system() == "Windows":
DEPENDENCY_LINKS.append("https://download.pytorch.org/whl/torch_stable.html")
def fetch_requirements(filename):
with open(filename) as f:
return [ln.strip() for ln in f.read().split("\n")]
setup(
name="PromptMoE",
version="1.0.1",
author="Hanzi Wang",
description="PromptMoE & QformerMoE Based on LAVIS",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="Vision-Language, Multimodal, Image Captioning, Generative AI, Deep Learning, Library, PyTorch",
license="3-Clause BSD",
packages=find_namespace_packages(include="lavis.*"),
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.7.0",
include_package_data=True,
dependency_links=DEPENDENCY_LINKS,
zip_safe=False,
)

5570
test.pdf/backward_graph Normal file

File diff suppressed because it is too large Load Diff

BIN
test.pdf/backward_graph.pdf Normal file

Binary file not shown.

360
test.txt Normal file
View File

@ -0,0 +1,360 @@
tmp_name = [name for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)]
tmp = [p for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)]
tensor([[-1.4032e-02, 3.7242e-03, 8.4997e-03, -3.4016e-03, -6.4855e-03,
4.3595e-02, 3.4423e-02, -8.6274e-03, -1.9702e-02, 9.1813e-03,
1.1643e-02, 2.3939e-02, -2.0908e-02, 3.4555e-03, 9.1636e-03,
1.5413e-02, 2.4148e-02, -1.0880e-03, 1.1193e-02, -1.3591e-02,
9.3484e-03, 1.5999e-02, -9.6086e-04, 3.8322e-02, -8.0687e-03,
-1.4056e-02, 3.9486e-02, 3.5167e-02, -9.3226e-03, -1.0493e-02,
-2.5795e-02, -9.7541e-03, 4.4437e-03, 7.7226e-03, 7.5210e-03,
-1.3526e-02, -5.0316e-03, -1.1149e-02, 6.0583e-03, 2.0564e-02,
-6.4477e-03, 1.4170e-02, -3.7847e-02, 1.1780e-02, 1.3321e-02,
-8.2501e-03, -1.0298e-02, 1.4805e-02, -1.2432e-02, -1.9159e-02,
-5.7095e-04, -3.8618e-02, -2.4230e-02, -1.4991e-03, -1.4114e-02,
-1.5365e-02, 1.5640e-02, -4.8623e-02, -2.9991e-02, 1.2796e-02,
-4.9917e-03, 2.3846e-03, 7.7368e-03, 1.2913e-02, 1.5300e-02,
8.5125e-03, 1.1582e-02, 8.1161e-03, 4.2259e-03, 7.6109e-03,
-2.0747e-02, -3.5099e-03, 2.2282e-02, 5.0493e-02, -1.7849e-02,
-3.7106e-02, -1.4944e-02, -1.4582e-02, -2.2458e-02, -4.6173e-05,
-8.1270e-03, 1.9037e-02, -2.0086e-02, 3.0980e-03, -9.3947e-03,
1.3054e-02, 2.3203e-02, -9.9304e-03, -2.6038e-02, 1.8679e-02,
9.2081e-03, -2.1770e-02, -1.6568e-03, -3.6503e-02, 2.0054e-02,
1.2886e-02, -1.8021e-02, 3.4457e-02, -1.3704e-02, -6.1498e-03,
-8.6769e-03, 1.5024e-02, -1.3875e-02, 1.7416e-02, -1.1178e-02,
-2.4088e-02, -1.7802e-02, 3.3326e-02, -1.1216e-02, -8.6330e-03,
-5.5359e-03, -1.1939e-02, -1.7777e-02, -2.8666e-02, -3.8280e-02,
4.2682e-02, 1.4946e-02, 9.6427e-03, 8.2754e-03, -1.0516e-03,
2.9560e-02, 2.4552e-03, -4.8354e-02, 1.5568e-02, 2.5881e-02,
-1.7354e-02, -3.1232e-02, 2.3683e-02, -2.3239e-02, 2.2966e-02,
5.6349e-03, -8.7595e-03, 1.5173e-02, 2.7660e-02, -4.3304e-03,
-2.5330e-02, -2.1795e-02, 1.6856e-02, -2.1587e-04, 2.3707e-02,
-2.3667e-02, 3.5378e-02, -7.9245e-03, 7.1029e-04, -3.2800e-02,
-1.5402e-03, -8.5634e-03, -1.1356e-02, -2.1935e-03, -1.8854e-02,
-1.9705e-03, -3.8333e-02, 2.9131e-02, -4.4470e-02, -2.0893e-03,
1.2937e-02, -1.7116e-02, 2.7778e-02, 1.0311e-02, -6.4017e-03,
3.7647e-02, -1.9953e-02, -5.3925e-03, 3.6978e-02, -1.5534e-02,
1.2241e-02, 1.3597e-02, 2.0703e-03, 2.4213e-03, 9.2604e-03,
6.6108e-03, -5.8213e-03, 9.8167e-03, -9.8300e-04, -1.0236e-02,
2.9581e-02, 1.0987e-02, 2.0046e-02, -1.0500e-02, -3.2221e-03,
-2.6303e-02, 1.3688e-02, -2.2529e-02, -5.7654e-03, 1.1784e-02,
1.6221e-02, 2.8743e-02, 5.7565e-03, 1.8129e-02, 1.5140e-02,
-1.1748e-02, -1.7528e-02, 4.7977e-02, 1.5568e-02, 4.7030e-04,
3.2757e-03, 1.6631e-02, 1.9986e-02, -7.3463e-03, 1.1435e-02,
-1.4739e-02, -3.2959e-03, -2.8770e-03, 2.9260e-02, 1.7007e-02,
3.0611e-02, 2.2102e-02, -3.3819e-02, -1.9403e-02, 2.5524e-02,
3.0738e-02, -1.9951e-02, -1.4553e-02, -1.5796e-02, -2.3143e-02,
-2.8826e-02, 2.4739e-02, -5.8602e-03, 4.1871e-02, 5.0821e-04,
3.3493e-02, 2.3524e-02, 2.3191e-02, 9.0416e-03, 3.3262e-02,
-1.6805e-02, 1.1545e-02, -1.7195e-02, -3.8696e-02, -8.4358e-04,
-8.1605e-03, 3.1372e-03, 1.0726e-03, 1.0865e-03, 1.0760e-02,
-5.2421e-03, 1.3039e-02, 3.6873e-04, 1.0464e-02, -1.1544e-02,
-2.2775e-02, -4.8439e-02, -1.0711e-02, 4.4236e-03, 2.0351e-02,
2.4479e-03, -1.9968e-02, -2.2941e-02, -2.0486e-02, -1.9528e-02,
-2.3176e-02, -3.2731e-03, 1.1789e-02, 2.0921e-02, 2.9809e-03,
-8.8507e-03, -3.5716e-02, 8.8418e-03, 5.3665e-05, -1.1288e-02,
-7.5571e-03, 2.1053e-02, -3.7381e-03, -4.0165e-03, -2.2628e-03,
3.7554e-03, -1.6597e-02, 7.6946e-03, -3.2689e-02, 2.2016e-02,
5.5122e-03, 4.5455e-02, 6.7586e-03, 1.5714e-02, 5.2125e-03,
3.9596e-03, 1.8134e-02, 1.5834e-03, -1.6239e-02, -1.3889e-02,
-2.3522e-02, 1.4738e-02, 5.5867e-03, -7.0727e-03, -2.8140e-03,
1.6849e-02, -3.1327e-02, -3.2443e-02, 4.7851e-03, 1.2980e-02,
-2.0014e-04, -9.9475e-03, 8.0657e-03, 1.9468e-02, -1.5774e-02,
1.7017e-02, -8.7196e-03, -4.0681e-03, -6.9754e-03, -2.2007e-02,
-6.6217e-03, -1.8219e-02, 4.2186e-02, -5.6621e-03, -9.3449e-03,
-1.1662e-02, 2.8700e-02, -9.0654e-03, 3.1569e-02, -2.9825e-03,
-3.8198e-02, -5.2723e-02, -4.8325e-02, -2.7871e-03, 5.1127e-03,
1.4511e-02, 9.3245e-03, -2.3339e-02, -8.6658e-03, 1.5276e-02,
-1.5823e-02, -3.4476e-03, 1.4601e-02, 6.3504e-03, -1.4307e-02,
2.2817e-02, 2.1998e-02, 1.7330e-02, -2.4448e-02, 4.0178e-03,
3.2280e-03, -1.2721e-02, 1.9661e-02, 7.5263e-03, 2.0245e-02,
4.5525e-02, -1.5658e-02, -4.0676e-02, 9.3160e-03, 1.1920e-02,
-1.9317e-02, 1.7848e-02, -5.8601e-03, 1.1786e-03, 8.3864e-03,
-1.8341e-02, 2.5985e-02, -1.1387e-02, -1.5069e-02, -2.8097e-02,
2.4966e-02, 1.4790e-02, 2.0424e-02, -1.3062e-02, 3.1314e-02,
1.7811e-02, 7.2393e-03, 1.4413e-02, -1.2746e-02, 3.1039e-02,
-1.1697e-02, -1.4826e-02, -8.8397e-03, 1.5157e-02, -1.5855e-02,
-1.8157e-03, 1.3024e-02, -1.8902e-03, 2.5212e-02, -3.4886e-02,
4.3029e-02, -4.0842e-02, 1.1362e-02, -1.4654e-02, -1.3337e-02,
-3.1832e-02, 3.6222e-03, 8.2804e-03, -1.4269e-02, 2.8399e-03,
-1.2008e-02, 2.4685e-02, -4.3070e-03, 6.3163e-03, -1.3517e-02,
-1.3807e-02, 2.4617e-02, 2.1453e-02, 4.7332e-03, 9.1636e-03,
-1.2881e-02, 1.9077e-02, 1.7571e-04, -5.2817e-03, -2.8821e-02,
5.8223e-03, -3.0979e-02, 2.4609e-02, 3.6666e-02, -1.0950e-02,
2.0421e-02, -2.6378e-03, 3.1825e-02, -9.6689e-04, -2.8398e-02,
-2.7513e-02, 1.6946e-02, -2.4110e-02, -1.3575e-02, -1.3443e-02,
8.4217e-03, 2.6754e-02, -2.3309e-03, -2.5086e-02, 1.1844e-02,
1.4152e-02, 1.2989e-02, -5.7336e-03, 4.7391e-03, 3.4106e-02,
1.0142e-02, -1.8029e-02, -1.5410e-04, -1.3548e-02, 9.1742e-03,
-3.0150e-02, 1.5666e-02, 4.3049e-03, 1.6273e-02, 2.0672e-02,
-1.2458e-02, 4.5496e-02, 3.2131e-02, -3.0967e-03, 2.1891e-02,
2.5524e-02, -1.1998e-02, -1.8866e-03, -1.0945e-02, 5.9930e-03,
-8.4233e-03, -8.9095e-03, -1.8261e-02, 1.9308e-02, -1.9728e-02,
-1.4216e-02, 1.4952e-02, 5.7355e-04, -2.4753e-02, -1.0948e-02,
1.0965e-02, 1.3607e-03, 3.4974e-02, -4.1396e-03, 2.5519e-02,
1.0364e-02, -1.5851e-02, -4.9224e-03, 1.0903e-02, -1.0523e-04,
3.1355e-02, -1.5105e-02, 5.6972e-03, -8.4078e-03, -1.9868e-02,
1.7186e-03, 2.9396e-02, -4.1439e-02, 1.4124e-02, -3.7745e-03,
3.3007e-02, 8.0368e-04, 8.5574e-03, 1.7269e-02, 1.1955e-02,
8.8142e-03, -1.3123e-02, 1.6817e-02, -1.5456e-02, -1.3868e-02,
2.4139e-02, -9.1566e-03, -1.8477e-02, -4.7972e-03, -6.8459e-03,
1.6818e-02, 3.1645e-03, -3.0901e-02, -5.6036e-03, -1.4758e-02,
2.0473e-02, -7.5411e-05, 2.0673e-03, -7.0061e-03, 9.5544e-03,
1.6600e-02, -1.7315e-02, -2.0168e-02, -5.3008e-03, 2.0206e-02,
2.4209e-03, 2.1205e-02, -8.9188e-03, -4.1350e-04, -1.0638e-02,
1.3705e-02, 9.5925e-05, 3.8877e-02, 3.2884e-02, -2.7730e-03,
1.0052e-02, 1.9311e-02, 1.1341e-02, -1.2988e-02, -1.7157e-02,
3.2095e-02, -1.8493e-02, -9.2551e-03, -2.6509e-03, -1.1130e-02,
1.6581e-02, 1.0216e-02, 1.3687e-02, 1.1860e-02, -3.0462e-03,
-1.2082e-02, 2.8502e-03, -1.2620e-02, 8.8330e-03, 1.7357e-02,
1.8383e-02, -2.3130e-02, -3.2654e-02, 1.2853e-02, -7.8144e-03,
1.9418e-04, 3.8635e-03, 4.9333e-02, 1.9350e-02, -2.0643e-02,
8.4650e-04, 5.0242e-02, 1.6576e-02, -8.9166e-03, -5.8805e-03,
-4.1484e-02, 9.3217e-03, -1.1292e-02, -8.7944e-03, -3.3190e-03,
5.7970e-03, -6.6078e-03, -2.4052e-02, -5.6347e-03, 8.4539e-03,
1.9250e-02, 7.9559e-03, -3.0055e-03, -3.0398e-04, 2.7007e-02,
3.1046e-03, 1.8332e-02, 5.5470e-03, 6.6815e-03, 1.1466e-02,
1.9738e-02, 1.2176e-02, -2.0220e-02, 8.6928e-03, 4.2451e-03,
4.4517e-03, -5.1524e-03, 1.0805e-02, -2.1935e-02, -1.7575e-02,
-1.2529e-02, -2.2191e-02, -1.0854e-02, -9.4462e-03, -2.9102e-02,
2.6752e-02, -1.0919e-02, -2.6724e-02, 8.3694e-04, 2.9832e-03,
1.4416e-02, -2.9906e-02, 2.3556e-02, -6.6624e-03, 2.6671e-02,
-3.6474e-02, 1.7237e-02, -2.5176e-02, 6.5560e-03, -2.6062e-02,
-2.3838e-02, 3.0629e-02, 2.5382e-02, 1.2302e-02, -1.1665e-02,
-7.0603e-03, 1.9931e-02, 2.3401e-02, -2.6047e-03, -2.7728e-02,
-1.7212e-02, 2.3061e-02, -2.5961e-02, 3.9764e-04, -2.9022e-02,
-1.5546e-03, 4.5519e-03, 2.3589e-02, -3.5005e-02, 4.1890e-03,
-1.5586e-02, 1.2389e-02, -2.1045e-02, 1.6377e-03, -1.1328e-02,
1.0195e-02, 6.4322e-03, -3.8431e-02, 2.2918e-02, -4.0123e-03,
6.6680e-02, 4.1135e-02, -1.5031e-02, -1.3550e-02, -2.2566e-02,
-2.3622e-03, -2.9323e-02, 2.1756e-02, 1.8399e-03, -4.2460e-03,
-1.5128e-03, -2.4731e-02, 1.8663e-02, 1.3469e-02, -1.3897e-02,
2.6399e-02, -8.0740e-03, -4.6753e-03, 3.9857e-02, 6.2364e-03,
2.2371e-03, 2.1501e-03, 5.9443e-02, 1.3574e-02, 7.6483e-03,
-6.2290e-03, 1.4324e-02, 1.2572e-02, 2.7331e-02, -6.0165e-03,
-5.9154e-03, -3.7000e-02, 1.4001e-02, 1.2869e-02, -2.8854e-02,
-9.4147e-03, 8.3965e-03, -1.4530e-03, -7.4215e-03, 9.0369e-03,
-2.4612e-02, 2.0625e-02, 2.2329e-02, -1.5216e-02, 1.4947e-03,
-3.6020e-02, -2.0702e-02, -4.0410e-02, -1.3157e-02, -1.5085e-02,
1.2911e-02, -2.7552e-02, -2.9781e-02, -4.7424e-03, 2.0521e-02,
-4.0043e-02, -4.8763e-02, -1.3175e-02, 2.6802e-02, 2.8869e-02,
6.5014e-03, -2.3213e-02, 1.4438e-02, -7.6318e-03, -1.9928e-03,
1.8509e-03, 2.9728e-03, 1.5225e-02, -2.9405e-03, -7.2875e-03,
2.9562e-05, -1.8661e-02, 9.1341e-03, -2.4919e-02, 2.9786e-02,
9.5186e-03, 1.5435e-02, -1.1080e-02, 1.1192e-02, -2.7315e-03,
6.9769e-05, -1.5392e-02, 4.9892e-03, 7.9857e-03, 2.0063e-02,
-2.0283e-02, -1.2596e-02, -4.1985e-04, -6.9686e-03, -5.4704e-02,
-1.9142e-02, 9.9706e-03, 2.3217e-02, -5.0579e-03, -4.9132e-02,
2.0023e-02, -2.6238e-02, 1.0709e-02, 2.1528e-02, -1.6390e-03,
-6.7829e-03, 1.3211e-02, -9.6793e-03, 1.3130e-02, -1.2878e-02,
1.7365e-02, 1.2509e-02, 1.2986e-03, -3.9292e-02, 9.5784e-03,
-8.0514e-03, -3.5619e-02, -3.2298e-02, 6.5933e-04, 9.9298e-03,
3.7268e-02, -3.4047e-02, -7.8385e-03, 2.3999e-02, 1.0386e-02,
1.7853e-02, -1.0122e-04, 5.2483e-04, -7.3150e-03, 1.0818e-02,
1.6245e-02, -3.5619e-02, -9.9190e-03, 4.0132e-03, 9.7788e-03,
2.7039e-02, -4.7858e-02, -2.0010e-02, -2.3702e-02, 7.8376e-04,
-2.5326e-02, 1.1698e-02, -1.3041e-02, 3.8634e-03, 9.3083e-03,
4.8204e-03, 3.9503e-02, -4.1356e-03]], requires_grad=True)
model.Qformer.bert.encoder.layer[10].experts.gate.weight
layer 11
0:
model.Qformer.bert.encoder.layer[11].output.dense.weight.grad
model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad
nan:
model.Qformer.bert.encoder.layer[11].attention.output.dense.weight.grad
model.Qformer.bert.encoder.layer[11].attention.self.query.weight.grad
model.Qformer.bert.encoder.layer[11].experts.intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.output_query.dense.weight.grad
None:
model.Qformer.bert.encoder.layer[11].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].output_query.dense.weight.grad
layer 8
0:
model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].output_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].output_query.dense.weight.grad
nan:
model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].output_query.dense.weight.grad
(Qformer)model.Qformer.bert.encoder.layer[8].intermediate_query.dense.weight.grad
None:
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad == None
model.Qformer.bert.encoder.layer[8].experts.gate.weight.requires_grad == True
model.Qformer.bert.encoder.layer[6].experts.gate.weight
Qformer.bert.encoder.layer.6.experts.gate.weight
tensor([[-0.0089, -0.0123, -0.0168, ..., -0.0072, 0.0295, -0.0167],
[ 0.0305, 0.0277, -0.0215, ..., 0.0149, 0.0016, -0.0415],
[ 0.0199, 0.0151, 0.0237, ..., 0.0007, 0.0023, 0.0167]],
requires_grad=True)
tensor([[-0.0089, -0.0123, -0.0168, ..., -0.0072, 0.0295, -0.0167],
[ 0.0305, 0.0277, -0.0215, ..., 0.0149, 0.0016, -0.0415],
[ 0.0199, 0.0151, 0.0237, ..., 0.0007, 0.0023, 0.0167]],
requires_grad=True)
tensor([[ 4.5972e-02, -1.5231e-02, -6.9533e-03, 3.2431e-02, -7.9703e-03,
1.5567e-02, 2.9619e-03, -2.2609e-04, 1.8580e-02, -2.8783e-02,
1.3093e-02, -1.0594e-02, 1.1918e-02, 4.4701e-02, 2.0108e-02,
-1.1011e-03, -8.2449e-03, 8.8876e-03, 4.6096e-03, 2.3274e-02,
-9.2557e-03, 2.5704e-03, 1.8919e-02, -5.3251e-03, -3.2665e-03,
-3.2663e-02, -5.6756e-02, -2.3400e-02, 1.3674e-02, -6.6185e-03,
1.4429e-03, 1.2354e-02, 2.5934e-03, 2.1895e-02, -1.9793e-02,
1.5497e-03, 4.3056e-03, -4.0023e-02, 9.8740e-03, 3.8631e-03,
-1.2918e-02, -3.6782e-02, -9.8365e-03, 3.2182e-02, 2.3729e-02,
2.3509e-03, 1.8473e-02, 1.5583e-02, -1.1029e-02, -1.0738e-02,
-3.0278e-02, -9.8731e-03, -1.0500e-02, 7.9832e-05, -1.0345e-02,
8.2803e-03, -5.9923e-03, -1.2669e-02, 1.2065e-03, 7.5720e-03,
-1.9286e-02, 4.0070e-02, 3.6221e-03, -1.7486e-02, 2.1725e-02,
-3.3231e-02, 7.3948e-03, -1.0924e-02, 3.1448e-02, 1.2101e-02,
6.1737e-03, -2.0851e-02, -3.7964e-02, 8.0938e-03, -8.8967e-03,
2.5925e-02, -7.8063e-04, 8.6102e-03, 2.7370e-02, 1.2323e-02,
4.0606e-03, 3.9316e-02, -1.0837e-02, -2.6835e-03, 3.1941e-03,
-1.2017e-02, -2.3022e-02, 8.3533e-03, -2.2668e-02, 1.4438e-02,
-2.3664e-02, 4.5595e-02, -1.0962e-02, 1.7547e-02, -1.6739e-03,
1.2048e-02, 2.0544e-02, 2.8837e-02, -1.6736e-02, 2.1207e-02,
8.7612e-03, 2.8757e-02, -3.8561e-03, 8.4050e-03, -1.1503e-02,
-5.8332e-03, 1.5734e-02, -1.0773e-02, 7.5827e-03, 6.5794e-03,
2.4291e-02, 2.6811e-02, 1.1681e-02, -3.3246e-02, 4.5776e-03,
-9.0628e-04, -2.9400e-02, 4.2933e-03, 1.5885e-03, 5.5757e-02,
7.5518e-03, 1.0099e-02, 5.3507e-03, -3.0182e-02, 2.0830e-02,
1.0102e-02, -9.3074e-03, 3.1161e-02, -1.7800e-02, -4.4445e-03,
-3.1503e-02, 2.3028e-02, 8.3472e-03, 7.4444e-03, 1.8838e-02,
-1.1977e-02, -2.6713e-02, 1.1364e-02, 8.3522e-04, 3.3736e-03,
6.9425e-03, -2.0632e-02, 1.8155e-02, -2.1711e-02, -3.4703e-02,
-3.6268e-03, -4.8810e-03, -2.8142e-02, -1.5781e-02, -3.3166e-02,
-2.9910e-02, -9.7459e-03, -6.7474e-03, 1.7988e-02, 9.0176e-03,
1.9452e-02, 4.2009e-02, 1.7217e-02, 1.4959e-02, -1.6552e-02,
-3.8206e-03, -2.4889e-02, 7.7993e-03, -1.9285e-02, -1.9770e-02,
2.6936e-02, -5.0484e-03, -2.5117e-02, -2.3122e-02, 1.3754e-02,
1.6025e-02, -9.1569e-03, -2.0068e-02, -1.6013e-02, -2.1775e-02,
-2.4154e-02, 6.2840e-03, -1.3684e-02, 2.5378e-02, -1.3166e-02,
-1.2201e-02, 1.0011e-02, -8.2324e-03, -5.6623e-03, -1.0383e-02,
-1.6251e-02, 1.0723e-02, -3.0207e-03, -6.9374e-03, -2.3161e-03,
-2.0850e-03, -3.4216e-02, 3.3997e-02, 3.7444e-02, -3.4273e-02,
1.5051e-02, -9.5605e-03, -2.6979e-03, 1.8848e-02, 2.3090e-02,
1.9669e-02, -3.9656e-02, 1.0453e-02, 5.2222e-03, -7.2493e-03,
1.4122e-02, 5.6583e-04, -1.3991e-02, 4.0975e-02, 1.3947e-02,
4.6919e-03, 7.9121e-03, 2.6936e-02, 1.2338e-02, 1.9048e-02,
7.7740e-03, -6.4494e-03, -5.2965e-02, 8.1929e-03, -1.3503e-02,
3.7466e-03, -3.3504e-02, -8.1192e-03, 1.0463e-02, -2.1568e-02,
1.0076e-02, -1.3420e-02, -6.3353e-04, 7.4253e-03, 2.2281e-02,
5.2829e-03, 1.4102e-02, 1.4427e-02, 1.6331e-02, -2.3305e-04,
-4.4875e-02, 6.5300e-03, 2.4963e-02, 2.2141e-03, 3.9830e-02,
1.1405e-02, 8.6810e-03, -2.0404e-03, -1.8579e-03, 1.4765e-02,
5.4752e-03, -1.3364e-02, -1.3082e-03, 1.5873e-03, 1.9309e-02,
3.4367e-02, 1.8459e-02, -1.1323e-02, -1.8764e-02, -1.5370e-02,
3.6180e-03, 2.8253e-02, -1.6867e-03, 3.5884e-03, -2.1952e-02,
-1.5026e-02, -2.1070e-02, -1.2149e-02, 1.1162e-02, -3.0343e-02,
-4.1372e-02, 1.0880e-02, 2.2365e-02, 1.2896e-02, 2.9694e-02,
-8.4248e-03, -7.8876e-03, -6.7049e-03, 2.3700e-02, 4.7528e-03,
-7.8350e-03, -5.9220e-03, 3.8396e-02, -4.1598e-02, -2.3161e-03,
1.3419e-02, 7.1029e-03, 1.4195e-02, -1.1124e-02, 1.5812e-02,
-1.9789e-02, -2.3883e-02, -8.2788e-04, 1.4670e-02, -2.1482e-02,
-1.1182e-02, -1.6532e-02, -8.0637e-03, -3.7822e-02, 3.9402e-02,
-1.4097e-03, -7.6648e-03, -3.7156e-02, 2.5791e-02, 6.1038e-03,
-6.3429e-03, 3.2865e-03, 3.6277e-02, 9.4312e-03, -2.1003e-02,
-3.6885e-03, 1.7147e-02, -1.3079e-02, -4.9414e-02, -3.2066e-02,
1.4835e-02, -2.9742e-02, 1.8358e-02, -2.1733e-02, 3.0256e-03,
1.7825e-02, 1.1079e-02, 1.1619e-02, -2.3680e-02, -7.8721e-03,
2.4456e-03, 4.3608e-02, -4.5674e-03, -3.6818e-02, 3.3952e-02,
3.3108e-02, -3.1665e-03, -2.3468e-03, 1.5091e-02, 7.0856e-03,
1.1723e-02, -2.0713e-02, -6.9180e-03, 3.7929e-02, 3.7671e-03,
4.6663e-02, 9.5301e-03, 1.2638e-02, -6.5623e-03, -3.1771e-03,
-1.7568e-02, 1.8711e-03, -1.2310e-02, 2.1518e-02, 4.3408e-03,
-6.7171e-03, -5.0451e-03, 2.6870e-02, -1.9832e-02, 7.0422e-03,
1.1274e-02, -2.4637e-02, -4.8450e-03, 2.1892e-02, -2.6059e-02,
1.5605e-02, -1.1617e-02, -1.9273e-02, -8.6735e-04, -9.8002e-04,
-1.8553e-02, 2.1239e-02, 2.1078e-02, -1.2091e-02, 9.7025e-03,
1.3426e-02, -1.1710e-02, -2.2242e-03, 6.4133e-03, -1.4820e-02,
1.4682e-02, 3.0679e-02, 1.1526e-02, 1.0072e-02, -1.1572e-02,
2.6128e-02, 4.0879e-03, -1.7936e-02, 1.3715e-02, -2.3667e-02,
2.0419e-03, -1.6887e-02, 1.2595e-02, -2.1988e-02, -2.3777e-02,
-1.0399e-02, 2.4868e-03, -1.2265e-02, -1.8543e-02, 3.4672e-02,
2.1114e-02, 2.0523e-02, 7.6818e-03, 2.9282e-02, -5.9593e-03,
-2.8496e-02, 2.8482e-03, 3.6874e-04, 4.7455e-02, -2.9770e-02,
-2.0684e-02, -2.0749e-02, -5.7681e-02, -2.6175e-03, -2.4488e-02,
-5.2550e-03, -7.1191e-03, 3.8192e-02, 4.3438e-02, 5.4181e-03,
2.8392e-02, 1.9493e-02, -3.5262e-02, 1.4839e-02, 4.6481e-03,
1.7219e-02, 2.0160e-02, 4.9998e-03, 2.1316e-02, -8.7929e-04,
-2.1542e-02, 3.9816e-03, 1.5879e-02, 9.9231e-03, 1.3962e-02,
-5.3418e-03, 3.9857e-02, 2.0997e-02, -2.1291e-05, 1.8133e-02,
-1.2472e-02, 4.9437e-03, -1.5099e-02, 4.8860e-02, 6.1980e-03,
2.0197e-02, 1.3141e-04, -3.1087e-03, -2.2718e-03, 2.3804e-02,
6.0726e-03, -2.0485e-02, -2.0514e-02, -2.7679e-02, -3.0412e-02,
-1.7661e-02, -1.7462e-02, 7.5216e-03, 2.2238e-02, 1.1413e-03,
2.6647e-02, -2.3855e-02, 2.2652e-03, -4.3256e-03, -9.3274e-03,
2.5149e-02, 6.8432e-03, 4.2664e-03, 3.8221e-02, 7.7480e-03,
8.7203e-03, -1.2851e-03, -1.1325e-02, -1.0650e-02, -2.8079e-02,
-1.5375e-02, 2.2630e-02, -4.3439e-03, 1.3493e-02, -1.8223e-02,
9.9750e-03, -2.4560e-02, 1.0904e-03, -3.1198e-02, 4.7331e-03,
1.6713e-02, -1.7653e-02, -3.8674e-02, 1.5458e-02, 4.0555e-02,
6.9451e-03, 1.1988e-03, 8.0718e-04, 3.9985e-03, -2.2781e-02,
8.1173e-04, 2.0106e-02, -1.2800e-02, -1.2961e-02, -2.1273e-02,
-4.4104e-05, -3.6080e-02, -1.9392e-02, 3.2862e-02, -5.6041e-03,
2.3288e-02, -4.6795e-02, 1.7282e-02, 5.7052e-03, 2.2405e-02,
1.9871e-03, -1.4333e-02, 5.3773e-03, 4.3568e-02, 9.8980e-03,
-1.9403e-03, 1.8981e-02, -2.5712e-02, -3.3621e-03, 2.9886e-02,
1.3326e-03, 1.1318e-02, -3.3238e-03, -1.5494e-02, -3.0565e-02,
1.7137e-02, -2.7874e-02, -1.1257e-02, 3.2250e-02, -2.5293e-02,
-3.0693e-03, -2.7787e-02, 1.4931e-02, 2.4202e-03, -4.0572e-03,
5.0273e-03, 9.7496e-03, 2.2601e-02, 3.2389e-02, -1.1910e-02,
9.1037e-03, 5.6000e-02, -1.9640e-02, 1.5469e-02, -3.3027e-02,
1.4839e-02, 2.5071e-02, -1.2687e-02, -1.3466e-02, 1.9031e-02,
-7.3403e-03, -1.5207e-02, -1.4486e-02, 2.0678e-02, -4.1996e-02,
1.0585e-02, 3.6276e-02, 6.1149e-03, 1.6405e-02, 1.5643e-02,
1.5060e-02, -5.1235e-03, -2.2824e-02, -1.3752e-02, -1.5742e-02,
2.4032e-02, -2.1782e-03, -1.3158e-02, 3.9482e-03, 3.2267e-02,
-2.2632e-03, 1.2055e-02, 4.4731e-02, 1.8271e-02, -1.1486e-02,
1.7836e-02, 1.7886e-03, -2.4020e-02, 2.6064e-02, -2.2122e-04,
1.8643e-02, -2.9808e-02, -6.1845e-03, -4.4464e-03, 8.8374e-04,
1.5268e-02, 1.7205e-03, 5.7832e-02, -1.7486e-02, 1.1897e-02,
5.8081e-02, 1.7667e-02, -7.7282e-03, 1.4036e-02, -1.4936e-03,
6.0635e-04, 1.6124e-03, -1.6916e-02, -1.1239e-02, 1.8497e-02,
1.2334e-03, -2.0706e-02, 3.2959e-03, 2.9186e-02, 3.7506e-02,
1.2037e-02, -1.4903e-02, 8.5606e-03, 3.4136e-03, 1.1850e-02,
-7.4782e-03, 5.3924e-03, -2.4772e-02, 2.6840e-02, -2.7656e-02,
-3.2637e-02, -1.2779e-02, 1.0730e-02, 1.4096e-03, 3.1572e-02,
7.8976e-04, 3.1674e-02, 8.5333e-03, -1.2679e-02, 1.1176e-02,
-2.0446e-02, 1.8628e-02, -4.0158e-02, -2.3358e-02, -2.2504e-02,
-2.8759e-02, -1.4597e-02, -8.5879e-03, 1.0550e-02, -3.5556e-02,
-1.9046e-02, -1.9159e-02, -2.2703e-02, -7.2056e-03, 4.2380e-02,
-9.7475e-03, -2.4754e-02, 1.3992e-03, -1.0411e-02, 1.5708e-02,
-8.2899e-03, -6.4856e-03, 1.6359e-02, -5.1969e-04, -5.0958e-03,
-4.1232e-02, 2.7349e-03, -1.7723e-02, 1.3388e-02, 2.2776e-03,
-2.0786e-02, -1.8082e-02, -2.4866e-03, 2.2141e-02, 6.9998e-03,
-5.5714e-03, 2.1088e-02, 5.8745e-03, 1.2788e-02, 4.2977e-03,
5.8631e-03, -1.8121e-02, 1.9242e-03, 2.3622e-02, 1.4917e-02,
-5.3198e-03, -3.9222e-02, -2.4697e-02, 9.1218e-03, -1.0711e-02,
1.0268e-02, 1.5148e-02, -4.4508e-02, 4.6783e-03, 2.8093e-03,
9.1253e-03, -7.3281e-03, 1.0114e-03, -9.2369e-04, 1.4841e-02,
2.2642e-02, 2.3675e-02, 1.3902e-02, -5.6343e-03, 1.4851e-02,
-9.5169e-03, -3.1721e-02, 1.6696e-02, 2.9285e-02, -1.4090e-02,
2.1128e-02, 4.8656e-02, 3.8431e-02, -3.5470e-02, -4.8230e-03,
-1.6513e-02, 4.1917e-02, 8.9090e-03, -1.4022e-04, 4.0182e-03,
7.1723e-03, 3.1419e-02, -4.8508e-03, 1.7768e-03, -7.3688e-03,
3.4637e-03, -2.3227e-02, 3.9606e-05, -2.4731e-02, -1.3640e-02,
-5.1718e-03, 2.6662e-02, -1.2871e-02, -1.6009e-02, -5.3720e-03,
2.7397e-04, -3.4016e-03, 2.6429e-02, 3.8069e-02, 1.0929e-02,
-1.0620e-02, 1.2165e-02, -2.6018e-02, 1.6021e-02, 4.0644e-02,
-8.0898e-03, -3.5198e-02, -1.9602e-02, 2.4986e-02, -5.8400e-03,
3.2070e-02, -1.8265e-02, -5.4518e-03, 2.8195e-02, 5.5598e-02,
-3.9959e-02, 1.5521e-02, -2.8416e-02, 3.1130e-02, -1.0038e-02,
2.1522e-02, -1.1654e-02, 2.2382e-02, -5.4467e-03, -2.2840e-02,
2.7036e-03, -4.4607e-02, -4.1953e-02, 2.0079e-02, -5.0121e-03,
-1.7495e-02, 4.4070e-03, 3.7400e-04, 1.0899e-02, 1.7008e-02,
-1.6307e-02, -1.9986e-02, -2.3865e-02, -2.5618e-02, -2.9981e-02,
-2.7230e-03, 2.7079e-02, 5.2920e-03, 2.1069e-02, -2.5896e-02,
-1.6256e-02, -1.4182e-03, 1.1829e-02, 1.0360e-02, 2.8883e-02,
-6.8762e-03, 1.4032e-02, -4.3389e-03]], requires_grad=True)

109
test1.txt Normal file
View File

@ -0,0 +1,109 @@
from torchviz import make_dot
dot = make_dot(query_output.last_hidden_state, params=dict(self.Qformer.bert.named_parameters()))
log_dir = '/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/'
dot.render(filename="Pre_PromptMoE_RawProb_backward_graph", directory=log_dir, format="pdf")
# Pre-Prompt-MoE
model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight
model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].intermediate.dense.weight.grad
model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad
model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight
model.Qformer.bert.encoder.layer[10].experts.experts[2].dense1.weight
model.Qformer.bert.encoder.layer[10].experts.experts[1].dense1.weight
model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight
model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight == model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight
# Pre-MoE gate-sentence
# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad is not updated
# Pre-MoE gate-token
# updates normally
# Post-MoE gate-sentence
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad updates normally
# model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad is all 0 / -0
# model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad is all 0 / -0
# Route-MoE
# Pre-MoE: the computed beam_scores are problematic
# Post-Route updates the parameters of multiple experts and also updates the gate parameters
# Layer 6 updated the parameters of two experts (layer 6, layer 8)
# model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad is 0 (all zeros)
# model.Qformer.bert.encoder.layer[11].output.dense.weight.grad
model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[7].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[7].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[7].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[9].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[9].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[9].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[11].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.experts[1].intermediate_query.dense.weight.grad
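
The manual layer-by-layer checks above can be automated; a small helper written for these notes (not part of the repo) that groups trainable parameters by whether their gradient is None, NaN, or all zeros after a backward pass:

import torch

def classify_grads(model, key=""):
    status = {"none": [], "nan": [], "zero": [], "ok": []}
    for name, p in model.named_parameters():
        if not p.requires_grad or key not in name:
            continue
        if p.grad is None:
            status["none"].append(name)
        elif torch.isnan(p.grad).any():
            status["nan"].append(name)
        elif torch.count_nonzero(p.grad) == 0:
            status["zero"].append(name)
        else:
            status["ok"].append(name)
    return status

# e.g. classify_grads(model, key="experts.gate") to check the router weights of every layer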
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.10.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[-0.0328, 0.0414, 0.0010, ..., -0.0068, 0.0244, 0.0587],
[ 0.0120, 0.0458, 0.0171, ..., -0.0439, -0.0107, -0.0397],
[ 0.0239, 0.0191, -0.0145, ..., 0.0008, -0.0067, 0.0090],
...,
[ 0.0174, -0.0465, -0.0106, ..., -0.0095, 0.0153, -0.0195],
[-0.0151, -0.0082, -0.0320, ..., -0.0016, -0.0232, -0.0147],
[ 0.0142, -0.0286, 0.0161, ..., -0.0160, -0.0306, -0.0272]],
device='cuda:0', requires_grad=True)]
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.8.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[ 0.0024, 0.0218, -0.0186, ..., -0.0178, -0.0067, 0.0820],
[-0.0759, -0.0002, -0.0548, ..., 0.0292, 0.0531, 0.0779],
[-0.0220, -0.0037, -0.0520, ..., -0.0426, -0.0261, -0.0357],
...,
[-0.0448, 0.0471, 0.0133, ..., -0.0062, -0.0217, -0.0203],
[ 0.0532, 0.0197, 0.0320, ..., -0.0010, -0.0838, 0.0682],
[ 0.0284, 0.0038, -0.0007, ..., -0.0305, 0.0296, 0.0056]],
device='cuda:0', requires_grad=True)]
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.6.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[ 6.5176e-02, -4.6473e-02, -2.7396e-02, ..., 2.1774e-03,
6.1457e-02, 1.9180e-03],
[ 7.3707e-03, 6.1392e-02, -2.7108e-02, ..., 4.0778e-02,
-1.9791e-02, -1.1612e-02],
[ 2.1193e-02, -3.8323e-02, -6.0238e-02, ..., -1.4539e-02,
9.2965e-02, 3.9153e-02],
...,
[ 5.3203e-03, -1.7276e-02, -3.2191e-02, ..., -1.6435e-02,
-1.8553e-02, -2.8158e-02],
[-6.9853e-02, 9.2719e-03, -1.8895e-03, ..., -2.6425e-02,
1.4880e-03, 3.4505e-02],
[-1.2168e-03, 3.7038e-02, 4.8047e-02, ..., -3.4523e-03,
-1.3030e-05, -1.4778e-02]], device='cuda:0', requires_grad=True)]