Route MoE Promote (Post/Pre) update 0110

This commit is contained in:
root 2024-01-10 16:56:52 +08:00 committed by wanghanzi
parent ce67e5669a
commit eb022668a3
38 changed files with 12983 additions and 491 deletions

File diff suppressed because it is too large

Binary file not shown.

View File

@ -1,4 +1,4 @@
name: minigptv
name: promptmoe
channels:
- pytorch
- defaults

View File

@ -17,14 +17,14 @@ datasets:
# md5: aa31ac474cf6250ebb81d18348a07ed8
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_train.json
val:
url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
test:
url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json
# val:
# url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json
# storage:
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json
# test:
# url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json
# storage:
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json
images:
storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO

View File

@ -20,6 +20,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val_eval.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_mscoco_val2014_annotations.json
@ -29,6 +30,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_test.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json
images:

View File

@ -20,6 +20,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json
@ -32,6 +33,7 @@ datasets:
- https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json
storage:
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json
# - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval_part100.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json
- /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json

View File

@ -105,6 +105,8 @@ class COCOCaptionDataset(BaseDataset, __DisplMixin):
'Using language, provide a short account of the image.',
'Use a few words to illustrate what is happening in the picture.',
]
self.source = 'coco_cap'
def __getitem__(self, index):
# TODO this assumes image input, not general enough
@ -118,13 +120,20 @@ class COCOCaptionDataset(BaseDataset, __DisplMixin):
image = self.vis_processor(image)
caption = self.text_processor(ann["caption"])
instruction = random.choice(self.instruction_pool)
instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
# instruction = random.choice(self.instruction_pool)
# instruction = "<Img><ImageHere></Img> [caption] {} ".format(instruction)
q_input = ""
llm_input = random.choice(self.instruction_pool)
return {
"image": image,
"image_id": ann["image"],
"answer": caption,
"instruction_input": instruction,
"q_input": q_input,
"llm_input": llm_input,
"text_input": llm_input,
"text_output": caption,
"source": 'coco_cap',
}
class CaptionEvalDataset(BaseDataset, __DisplMixin):

View File

@ -31,6 +31,7 @@ class COCOCapEvalDataset(CaptionEvalDataset):
split (string): val or test
"""
super().__init__(vis_processor, text_processor, vis_root, ann_paths)
self.source = 'coco_cap'
def __getitem__(self, index):
ann = self.annotation[index]

View File

@ -31,7 +31,6 @@ class MultiIterLoader:
if ratios is None:
ratios = [1.0] * len(loaders)
else:
# import pdb; pdb.set_trace()
assert len(ratios) == len(loaders)
ratios = [float(ratio) / sum(ratios) for ratio in ratios]

View File

@ -12,7 +12,6 @@ from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from datasets import load_dataset
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
@ -248,6 +247,7 @@ if 'vsr' in args.dataset:
img_path = cfg.evaluation_datasets_cfg["vsr"]["img_path"]
batch_size = cfg.evaluation_datasets_cfg["vsr"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["vsr"]["max_new_tokens"]
from datasets import load_dataset
annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test')
data = VSREvalData(annotation, vis_processor, img_path)

View File

@ -386,17 +386,23 @@ class BertOutput(nn.Module): # Add & Norm
class FeedForward(nn.Module):
# remove LayerNorm
def __init__(self, config):
nn.Module.__init__(self)
# first layer
self.intermediate_query = BertIntermediate(config)
# second layer
self.output_query = BertOutput(config)
super().__init__()
self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2
# self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2
def forward(self, hidden_states: Tensor):
input_tensor = hidden_states
intermediate_output = self.intermediate_query(hidden_states)
hidden_states = self.output_query(intermediate_output, input_tensor)
hidden_states = self.dense1(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dense2(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
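
For reference, a minimal standalone sketch of the refactored expert FFN above: dense1 -> activation -> dense2 -> dropout, with LayerNorm and the residual deliberately left out so they can be applied once after the expert outputs are mixed. The sizes and dropout rate here are illustrative assumptions, not values read from the repo's BertConfig.

import torch
import torch.nn as nn

class ExpertFFNSketch(nn.Module):
    # Illustrative stand-in for the FeedForward expert above; hidden/intermediate
    # sizes and dropout are assumed, not taken from the actual config.
    def __init__(self, hidden_size=768, intermediate_size=3072, dropout=0.1):
        super().__init__()
        self.dense1 = nn.Linear(hidden_size, intermediate_size)
        self.act = nn.GELU()
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        # no LayerNorm and no residual here: both are applied outside,
        # after the expert outputs have been combined
        return self.dropout(self.dense2(self.act(self.dense1(hidden_states))))

x = torch.randn(4, 32, 768)
print(ExpertFFNSketch()(x).shape)  # torch.Size([4, 32, 768])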
@ -440,6 +446,7 @@ class BertLayer(nn.Module):
)
else:
self.experts = ffn
self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
@ -494,7 +501,8 @@ class BertLayer(nn.Module):
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask) # layer_output, gate_loss, gate_load
# import pdb; pdb.set_trace() # test0107
if attention_output.shape[1] > query_length: # have text input in Qformer
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
@ -503,6 +511,7 @@ class BertLayer(nn.Module):
attention_output[:, query_length:, :],
)
layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
@ -524,15 +533,14 @@ class BertLayer(nn.Module):
def feed_forward_query_moe(self, attention_output, expert_attention_mask):
if not self.use_experts:
layer_output = self.experts(attention_output)
hidden_states = self.experts(attention_output)
layer_output = self.expert_ln(hidden_states + attention_output)
return layer_output, 0.0, []
# if not self.importance_processor.is_moe:
# raise RuntimeError("Need to turn the model to a MoE first.")
layer_output, gate_loss, gate_load = self.experts(
hidden_states, gate_loss, gate_load = self.experts(
attention_output, expert_attention_mask
)
layer_output = self.expert_ln(hidden_states + attention_output)
return layer_output, gate_loss, gate_load
class BertEncoder(nn.Module):
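
The Add & Norm that used to live inside BertOutput is now applied once per MoE block through expert_ln, after the expert (or expert mixture) output. A minimal sketch of that wiring, using an expert FFN without any internal norm and assumed sizes:

import torch
import torch.nn as nn

hidden_size = 768
expert = nn.Sequential(  # stand-in expert FFN with no internal LayerNorm/residual
    nn.Linear(hidden_size, 3072), nn.GELU(), nn.Linear(3072, hidden_size)
)
expert_ln = nn.LayerNorm(hidden_size, eps=1e-12)

x = torch.randn(4, 32, hidden_size)          # query_attention_output
hidden_states = expert(x)                     # expert FFN output
layer_output = expert_ln(hidden_states + x)   # Add & Norm applied after the expert
print(layer_output.shape)                     # torch.Size([4, 32, 768])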

View File

@ -46,10 +46,9 @@ from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig
from minigpt4.models.moe.utils import (
FeedForward,
MoEModelOutput,
MoEModelOutputWithPooling,
use_experts,
use_experts_route,
moe_layer_judge,
)
from minigpt4.models.moe.route_moe_layer import RouteMoELayer
@ -378,13 +377,14 @@ class BertOutput(nn.Module): # Add & Norm
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
# Move LayerNorm & residual connection out of the FFN; applied after the MoE FFN
hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1
return hidden_states
@ -429,7 +429,7 @@ class BertLayer(nn.Module):
self.output_query = BertOutput(config)
# Add MoE FFN
self.use_experts = use_experts(layer_num)
self.use_experts = use_experts_route(layer_num)
self.layer_judge = moe_layer_judge(layer_num)
self.num_beams = config.moebert_num_beams
ffn = FeedForward(config)
@ -442,10 +442,13 @@ class BertLayer(nn.Module):
num_beams=config.moebert_num_beams,
layer_judge = self.layer_judge,
route_method=config.route_method,
weight_type=config.moe_weight_type,
)
else:
self.experts = ffn
# self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states,
@ -463,8 +466,8 @@ class BertLayer(nn.Module):
self_attn_past_key_value = (
past_key_value[:2] if past_key_value is not None else None
)
# import pdb;pdb.set_trace()
# import pdb; pdb.set_trace() # 0107test
# adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same
if self.num_beams > 1:
if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams:
@ -494,10 +497,6 @@ class BertLayer(nn.Module):
present_key_value = self_attention_outputs[-1]
# import pdb;pdb.set_trace()
# print(self.layer_num, hidden_states.shape, attention_mask.shape)
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
@ -526,7 +525,8 @@ class BertLayer(nn.Module):
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
# layer_output = (layer_output, beam_scores, expert_route, beam_idx)
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
# import pdb; pdb.set_trace() # 0107test
if attention_output.shape[1] > query_length: # have text input in Qformer
layer_output_text = apply_chunking_to_forward(
@ -535,7 +535,8 @@ class BertLayer(nn.Module):
self.seq_len_dim,
attention_output[:, query_length:, :],
)
if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
if self.layer_judge == 'first' and self.num_beams>1:
# if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
# adjust the dimension of layer_output_text to bz*num_beams
layer_output_text = self.adjust_layer_output_text(layer_output_text)
@ -550,7 +551,8 @@ class BertLayer(nn.Module):
# layer_output & layer_output_text dimen_0 from bz*num_beams to bz
layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text)
layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2])
layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4])
# import pdb; pdb.set_trace() # 0107test
else:
layer_output = apply_chunking_to_forward(
@ -559,7 +561,7 @@ class BertLayer(nn.Module):
self.seq_len_dim,
attention_output,
)
layer_output = (layer_output, None, None)
layer_output = (layer_output, None, None, None, 0.0)
outputs = (layer_output,) + outputs
@ -594,24 +596,27 @@ class BertLayer(nn.Module):
beam_scores_new = beam_scores[selects]
expert_route_new = expert_route[selects]
return (hidden_states_new, beam_scores_new, expert_route_new), layer_output_text
return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
# layer_output = self.LayerNorm(layer_output + attention_output)
return layer_output
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
if not self.use_experts:
layer_output = self.experts(attention_output)
return layer_output, None, None, None
# layer_output = self.LayerNorm(layer_output + attention_output)
return layer_output, None, None, None, 0.0
layer_output, beam_scores, expert_route, beam_idx = self.experts(
layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
attention_output, expert_attention_mask, beam_scores, expert_route
)
return layer_output, beam_scores, expert_route, beam_idx
# layer_output = self.LayerNorm(layer_output + attention_output)
return layer_output, beam_scores, expert_route, beam_idx, importance_loss
class BertEncoder(nn.Module):
def __init__(self, config):
@ -645,6 +650,7 @@ class BertEncoder(nn.Module):
next_decoder_cache = () if use_cache else None
beam_scores=None
expert_route=None
importance_loss = 0
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
@ -693,6 +699,7 @@ class BertEncoder(nn.Module):
hidden_states = layer_outputs[0][0]
beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1]
expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2]
importance_loss += layer_outputs[0][4]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
@ -724,6 +731,7 @@ class BertEncoder(nn.Module):
cross_attentions=all_cross_attentions,
beam_scores=beam_scores,
expert_route=expert_route,
gate_loss=importance_loss,
)
@ -1103,6 +1111,7 @@ class BertModel(BertPreTrainedModel):
cross_attentions=encoder_outputs.cross_attentions,
beam_scores=encoder_outputs.beam_scores,
expert_route=encoder_outputs.expert_route,
gate_loss=encoder_outputs.gate_loss
)
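
Each routed layer now returns a 5-tuple (hidden_states, beam_scores, expert_route, beam_idx, importance_loss), and the encoder sums the per-layer importance losses into the gate_loss field that is threaded through to the model output above. A small sketch of that accumulation loop, with a dummy layer standing in for BertLayer:

import torch

def dummy_routed_layer(hidden_states, beam_scores, expert_route):
    # stand-in for a routed BertLayer returning the 5-tuple contract used above
    importance_loss = torch.rand(())
    return hidden_states, beam_scores, expert_route, None, importance_loss

hidden_states = torch.randn(4, 32, 768)
beam_scores, expert_route, importance_loss = None, None, 0.0
for _ in range(12):  # num_hidden_layers
    out = dummy_routed_layer(hidden_states, beam_scores, expert_route)
    hidden_states = out[0]
    beam_scores = beam_scores if out[1] is None else out[1]
    expert_route = expert_route if out[2] is None else out[2]
    importance_loss += out[4]   # reported as gate_loss in the encoder output
print(float(importance_loss))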

View File

@ -62,7 +62,7 @@ class Blip2Base(BaseModel):
return Qformer, query_tokens
@classmethod
def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, cross_attention_freq=2):
def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2):
moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
moe_encoder_config.encoder_width = vision_width
@ -74,6 +74,7 @@ class Blip2Base(BaseModel):
moe_encoder_config.moebert_expert_num = moebert_expert_num
moe_encoder_config.moebert_num_beams = moebert_num_beams
moe_encoder_config.route_method = route_method
moe_encoder_config.moe_weight_type = moe_weight_type
RouteMoEQformer = BertMoERouteLMHeadModel.from_pretrained(
"/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config

View File

@ -99,6 +99,7 @@ class Blip2VicunaInstruct(Blip2Base):
moebert_expert_num=moebert_expert_num,
moebert_num_beams=moebert_num_beams,
route_method=moebert_route_method,
moe_weight_type=moe_weight_type,
cross_attention_freq=2
)
else:
@ -118,7 +119,6 @@ class Blip2VicunaInstruct(Blip2Base):
num_query_token, self.visual_encoder.num_features
)
# import pdb;pdb.set_trace()
if not qformer_text_input:
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
@ -178,6 +178,19 @@ class Blip2VicunaInstruct(Blip2Base):
if "_query" in name and "experts" not in name: # raw ffn_query not update
param.requires_grad = False
ln_pattern = r"bert\.encoder\.layer\.\d+\.expert_ln\.(weight|bias)"
if re.match(ln_pattern, name):
key_orig = re.sub('expert_ln', 'output_query.LayerNorm', name)
param.data.copy_(state_dict[key_orig])
d1_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense1\.(weight|bias)"
if re.match(d1_pattern, name):
key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense1', 'intermediate_query.dense', name)
param.data.copy_(state_dict[key_orig])
d2_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense2\.(weight|bias)"
if re.match(d2_pattern, name):
key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense2', 'output_query.dense', name)
param.data.copy_(state_dict[key_orig])
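
The new expert parameters (dense1/dense2 in each expert and expert_ln) are initialised by rewriting the parameter name back to its pretrained counterpart with re.sub and copying from the original state_dict. A toy sketch of just the renaming step, with fabricated keys that mirror the patterns above:

import re
import torch

state_dict = {  # fabricated pretrained keys, for illustration only
    "bert.encoder.layer.6.intermediate_query.dense.weight": torch.randn(3072, 768),
    "bert.encoder.layer.6.output_query.dense.weight": torch.randn(768, 3072),
}
new_param_names = [
    "bert.encoder.layer.6.experts.experts.0.dense1.weight",  # MoE expert copy
    "bert.encoder.layer.6.experts.dense2.weight",            # plain (non-MoE) layer
]
d1 = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense1\.(weight|bias)"
d2 = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense2\.(weight|bias)"
for name in new_param_names:
    if re.match(d1, name):
        key_orig = re.sub(r"experts(\.|\.experts\.\d+\.)dense1", "intermediate_query.dense", name)
    elif re.match(d2, name):
        key_orig = re.sub(r"experts(\.|\.experts\.\d+\.)dense2", "output_query.dense", name)
    print(name, "<-", key_orig, tuple(state_dict[key_orig].shape))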
# freeze qformer
if freeze_qformer:
for name, param in self.Qformer.named_parameters():
@ -205,6 +218,7 @@ class Blip2VicunaInstruct(Blip2Base):
self.use_moeqformer = use_moeqformer
self.use_route_moe = use_route_moe
self.moebert_load_balance = moebert_load_balance
self.moebert_num_beams = moebert_num_beams
self.gate_save_path = gate_save_path
# if self.gate_save_path != None:
@ -242,7 +256,7 @@ class Blip2VicunaInstruct(Blip2Base):
# print(samples["text_input"])
# print(samples["text_output"])
# print('-----------------')
# import pdb;pdb.set_trace()
# import pdb;pdb.set_trace() # 0107test
image = samples["image"]
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
@ -278,10 +292,10 @@ class Blip2VicunaInstruct(Blip2Base):
return_dict=True,
output_hidden_states=True,
)
# import pdb; pdb.set_trace()# 0107test
query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:]
if self.use_moeqformer and not self.use_route_moe:
if self.use_moeqformer:
gate_loss = query_output.gate_loss # only available in QformerMoE
if self.gate_save_path != None:
@ -312,7 +326,7 @@ class Blip2VicunaInstruct(Blip2Base):
# 'gate_route_1': prob_gate_normalized[0][i].tolist(),
})
# for layer in [6,8,10]:
# layer_data = all_hidden_states[layer]
# layer_data = all_hidden_states[layer]
# file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
# x = layer_data.data.cpu().numpy()
# np.save(file_path,x)
@ -323,7 +337,6 @@ class Blip2VicunaInstruct(Blip2Base):
print("Gate Save Error....")
print(e)
inputs_llm = self.llm_proj(query_output_to_linear)
atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
@ -380,7 +393,7 @@ class Blip2VicunaInstruct(Blip2Base):
labels=targets,
)
if self.use_moeqformer and not self.use_route_moe:
if self.use_moeqformer:
loss = outputs.loss + self.moebert_load_balance * gate_loss
else:
loss = outputs.loss
@ -441,6 +454,8 @@ class Blip2VicunaInstruct(Blip2Base):
output_hidden_states=True,
)
# import pdb; pdb.set_trace()
if self.gate_save_path != None:
all_hidden_states = query_output.hidden_states
# prob_gate_normalized = query_output.gate_loads
@ -471,11 +486,11 @@ class Blip2VicunaInstruct(Blip2Base):
# 'gate_route_3': prob_gate_normalized[2][i].tolist(),
# 'gate_route_1': prob_gate_normalized[0][i].tolist(),
})
for layer in [6,8,10]:
if layer == 6:
layer_data = all_hidden_states[layer][i, :32, :]
for layer in [6,7,8,9,10,11]:
if layer in [6,11]:
layer_data = all_hidden_states[layer][i, :, :]
else:
layer_data = all_hidden_states[layer][i*3, :32, :]
layer_data = all_hidden_states[layer][i*self.moebert_num_beams, :, :]
file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy')
x = layer_data.data.cpu().numpy()
np.save(file_path,x) # done saving
@ -683,5 +698,6 @@ class Blip2VicunaInstruct(Blip2Base):
for name, param in model.named_parameters():
if param.requires_grad == True:
print(name)
# [name for name, param in model.named_parameters() if (param.requires_grad == False and 'Qformer' in name and 'intermediate_query' in name)]
# import pdb; pdb.set_trace()# 0107test
return model

View File

@ -21,7 +21,6 @@ class MoELayer(nn.Module):
else:
raise KeyError("Routing method not supported.")
def _forward_gate_sentence(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
@ -77,7 +76,65 @@ class MoELayer(nn.Module):
print('Layer Qformer MoE: \n',prob_gate)
return moe_result, select_prob_gate, gate
def _forward_gate_sentence_post(self, x, attention_mask):
"""
x: query_attention_output; torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
bz = 4
x = torch.randn(bz,32,768)
attention_mask = torch.ones([bz, 32])
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
# input_x += torch.randn(4,32,768)
# return input_x
output_x = self.experts[expert_idx].forward(input_x)
return output_x
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx)
outputs.append(output_x.unsqueeze(0))
output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
# gate_acore = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
topk_values, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: experts assigned to each sample, torch.Size([bz, topk])
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
gate_load = num_sentences.clone()
# load balancing loss
if self.use_balance_loss:
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
# output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
# logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])
prob_gate_topk = torch.zeros_like(prob_gate)
prob_gate_topk.scatter_(1, gate, topk_values)
prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) # torch.Size([bz, num_expert])
candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
import pdb;pdb.set_trace()
return moe_result, (balance_loss+importance_loss), prob_gate_normalized
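
A condensed, self-contained sketch of the post-gating idea in _forward_gate_sentence_post: every expert processes the input, the shared gate scores each expert from its own token-averaged output, and the result is a renormalised top-k mixture. Experts and sizes below are toy stand-ins, not the repo's modules.

import torch
import torch.nn as nn
import torch.nn.functional as F

bz, num_tokens, hidden, num_experts, topk = 4, 32, 768, 3, 2
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])
gate = nn.Linear(hidden, 1, bias=False)            # shared single-logit gate

x = torch.randn(bz, num_tokens, hidden)
outputs, logits = [], []
for e in range(num_experts):
    out = experts[e](x)                            # [bz, 32, 768]
    outputs.append(out.unsqueeze(0))
    logits.append(gate(out.mean(dim=1)))           # score the expert from its own output
candidate = torch.cat(outputs)                     # [num_experts, bz, 32, 768]
prob_gate = F.softmax(torch.cat(logits, dim=1), dim=-1)        # [bz, num_experts]

topk_values, gate_idx = torch.topk(prob_gate, topk, dim=1)
prob_topk = torch.zeros_like(prob_gate).scatter_(1, gate_idx, topk_values)
prob_norm = prob_topk / prob_topk.sum(dim=1, keepdim=True)     # renormalise over the top-k

moe_result = (prob_norm.unsqueeze(-1).unsqueeze(-1) * candidate.permute(1, 0, 2, 3)).sum(dim=1)
print(moe_result.shape)                            # torch.Size([4, 32, 768])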
def forward(self, x, attention_mask):
if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x)
@ -95,7 +152,7 @@ class MoELayer(nn.Module):
class RouteMoELayer(nn.Module):
def __init__(self, hidden_size, expert, gate, num_experts, num_beams=2, layer_judge=None, route_method="pre-route"):
def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
@ -103,13 +160,26 @@ class RouteMoELayer(nn.Module):
self.num_beams = num_beams
self.hidden_size = hidden_size
self.layer_judge = layer_judge
self.weight_type = weight_type
self.route_method = route_method
if self.route_method == "pre-route":
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
elif self.route_method == "post-route":
# gate = nn.Linear(hidden_size, 1, bias=False).float()
self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
gate = nn.Linear(hidden_size, 1, bias=False).float()
self.gate = gate
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
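
The importance auxiliary loss above is the squared coefficient of variation of the total routing probability each expert receives (as in V-MoE). A tiny worked sketch with a hand-made prob_gate:

import torch

prob_gate = torch.tensor([[0.7, 0.2, 0.1],
                          [0.6, 0.3, 0.1]])            # [bz, num_experts]
importance = prob_gate.sum(dim=0)                       # mass per expert: [1.3, 0.5, 0.2]
loss = (importance.std() / importance.mean()) ** 2      # squared coefficient of variation
print(loss)   # 0 only when every expert receives the same total probability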
def forward_gate(self, x):
"""
@ -123,19 +193,21 @@ class RouteMoELayer(nn.Module):
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
return prob_gate
def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
import pdb;pdb.set_trace()
def beam_search_backup(self, current_scores_log, beam_scores, expert_route, batch_size):
if self.layer_judge=='first' and self.route_method=='pre-route':
# current_scores_log torch.Size([bz, num_experts])
assert beam_scores==None and expert_route==None
current_scores = torch.exp(current_scores_log)
topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: experts assigned to each sample, torch.Size([bz, topk])
beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
beam_idx = None
beam_idx = torch.tensor(range(self.num_beams * batch_size))
else:
if self.layer_judge=='first' and self.route_method == 'post-route':
batch_size = batch_size
next_scores_raw1 = torch.exp(current_scores_log) # torch.Size([bz, num_experts])
next_scores_raw1 = torch.exp(current_scores_log) # torch.Size([bz, num_beams*num_experts])
else:
batch_size = int(batch_size // self.num_beams)
next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking the log, probabilities can be added directly
@ -147,9 +219,6 @@ class RouteMoELayer(nn.Module):
next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
# next_scores torch.Size([bz, num_beams])
# next_tokens torch.Size([bz, num_beams])
print(next_scores_raw1)
print(next_scores)
print(next_experts)
next_batch_beam = list()
for batch_idx in range(batch_size):
@ -166,7 +235,7 @@ class RouteMoELayer(nn.Module):
next_batch_beam.extend(next_sent_beam)
import pdb;pdb.set_trace()
if self.layer_judge=='first' and self.route_method == 'post-route':
beam_scores = next_scores.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
expert_route = next_experts.view(self.num_beams * batch_size)
@ -181,33 +250,91 @@ class RouteMoELayer(nn.Module):
pre_route = expert_route[beam_idx,:]
expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
import pdb;pdb.set_trace()
return beam_scores, expert_route, beam_idx
def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
# current_scores_log torch.Size([bz, num_experts])
assert beam_scores==None and expert_route==None
current_scores = torch.exp(current_scores_log)
topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: experts assigned to each sample, torch.Size([bz, topk])
beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
beam_idx = torch.tensor(range(self.num_beams * batch_size))
import pdb;pdb.set_trace()
else:
batch_size = int(batch_size // self.num_beams)
next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking the log, probabilities can be added directly
next_scores_exp = torch.exp(next_scores_raw)
next_scores_raw1 = next_scores_exp.view(
batch_size, self.num_beams * self.num_experts
) # torch.Size([bz, num_beams*num_experts])
next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
# next_scores torch.Size([bz, num_beams])
# next_tokens torch.Size([bz, num_beams])
next_batch_beam = list()
for batch_idx in range(batch_size):
next_sent_beam = list()
for rank, (expert_id, expert_score) in enumerate(
zip(next_experts[batch_idx], next_scores[batch_idx])
):
expert_id = expert_id.item()
beam_id = expert_id // self.num_experts
ex_id = expert_id % self.num_experts
effective_beam_id = batch_idx*self.num_beams + beam_id
next_sent_beam.append((expert_score, ex_id, effective_beam_id))
next_batch_beam.extend(next_sent_beam)
# import pdb;pdb.set_trace()
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
pre_route = expert_route[beam_idx,:]
expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
print("next_scores_raw1:\n",next_scores_raw1)
return beam_scores, expert_route, beam_idx
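
A compact sketch of the expert beam search above: each sample keeps num_beams candidate routes, new gate scores are added to the accumulated beam scores in log space, the top num_beams of the num_beams*num_experts combinations survive, and the source beam and expert id are recovered with integer division and modulo. All tensors below are random stand-ins with assumed sizes.

import torch
import torch.nn.functional as F

bz, num_beams, num_experts = 2, 2, 3

# first routed layer: keep the top num_beams experts per sample
scores0 = F.softmax(torch.randn(bz, num_experts), dim=-1)
beam_scores, route = torch.topk(scores0, num_beams, dim=1)
beam_scores = beam_scores.reshape(bz * num_beams)             # [bz*num_beams]
expert_route = route.reshape(bz * num_beams, 1)               # [bz*num_beams, 1]

# later layer: combine accumulated route scores with the new gate scores
scores1 = F.softmax(torch.randn(bz * num_beams, num_experts), dim=-1)
next_raw = torch.exp(torch.log(scores1) + torch.log(beam_scores).unsqueeze(1))
next_raw = next_raw.view(bz, num_beams * num_experts)
next_scores, next_ids = torch.topk(next_raw, num_beams, dim=1)

rows = []
for b in range(bz):
    for score, idx in zip(next_scores[b], next_ids[b]):
        beam_id, expert_id = idx.item() // num_experts, idx.item() % num_experts
        rows.append((score, expert_id, b * num_beams + beam_id))   # (score, expert, source beam)
beam_scores = torch.stack([r[0] for r in rows])
beam_idx = torch.tensor([r[2] for r in rows])
expert_route = torch.cat([expert_route[beam_idx],
                          torch.tensor([r[1] for r in rows]).unsqueeze(1)], dim=1)
print(expert_route)   # one expert appended per routed layer, per surviving beam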
def forward_expert_ffn(self, x, expert_select, beam_scores):
def forward_expert_ffn(self, x, expert_select, current_scores):
"""
x_repeat : [bz*num_beams, 32,768]
expert_select : [bz*num_beams]
current_scores : [bz*num_beams, num_experts] / [bz, num_experts]
"""
# add_1212 l2_normalization
# normalized_tensor = torch.nn.functional.normalize(beam_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
# add_1228 l2_normalization
# normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
# tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)
import pdb;pdb.set_trace()
outputs = list()
for i in range(x.shape[0]):
output_x = self.experts[expert_select[i]].forward(x[i])
outputs.append(output_x.unsqueeze(0))
candidate_output = torch.cat(outputs)
for i in range(self.num_experts):
output_x = self.experts[i].forward(x)
outputs.append(output_x.unsqueeze(1))
candidate_output = torch.cat(outputs, dim=1)
expert_select_matrix = F.one_hot(expert_select, self.num_experts)
# candidate_output = candidate_output * tmp_prob
return candidate_output # torch.Size([bz*num_beams, 32, 768])
if self.weight_type == 'ffn_prob':
tmp_prob = current_scores * expert_select_matrix
candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
import pdb;pdb.set_trace()
output = torch.sum(candidate_output, dim=1)
return output # torch.Size([bz*num_beams, 32, 768])
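
The new weight_type='ffn_prob' option scales the selected expert's output by its routing probability before the sum over experts; any other setting keeps the plain one-hot selection. A tiny sketch of that step with random tensors:

import torch
import torch.nn.functional as F

rows, num_experts, num_tokens, hidden = 4, 2, 32, 768            # rows = bz*num_beams
candidate = torch.randn(rows, num_experts, num_tokens, hidden)   # per-expert outputs
scores = F.softmax(torch.randn(rows, num_experts), dim=-1)       # routing probabilities
select = torch.randint(0, num_experts, (rows,))                  # expert chosen per row

select_matrix = F.one_hot(select, num_experts)                   # [rows, num_experts]
weight_type = "ffn_prob"
if weight_type == "ffn_prob":
    w = scores * select_matrix          # keep the probability of the chosen expert
else:
    w = select_matrix.float()           # unweighted selection
output = (candidate * w.unsqueeze(-1).unsqueeze(-1)).sum(dim=1)  # [rows, 32, 768]
print(output.shape)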
def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
current_scores = self.forward_gate(x) # [bz*num_beams, 5]
import pdb;pdb.set_trace()
current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams]
importance_loss = self._importance_auxiliary_loss(current_scores)
if use_log:
current_scores_log = torch.log(current_scores) # log-probabilities can be added directly
@ -215,42 +342,45 @@ class RouteMoELayer(nn.Module):
current_scores_log = current_scores
batch_size, num_tokens = x.shape[0], x.shape[1]
beam_scores, expert_route, _ = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
current_expert_select = expert_route[:,-1]
import pdb;pdb.set_trace()
if self.layer_judge=='first': # expand first dim to batch_size * num_beams
replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
candidate_output = self.forward_expert_ffn(x, current_expert_select, beam_scores) # [bz*num_beams, 32,768]
return candidate_output, beam_scores, expert_route
input_x = x[beam_idx]
candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768]
import pdb;pdb.set_trace()
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
# if self.layer_judge=='first': # expand first dim to batch_size * num_beams
# batch_size, num_tokens = x.shape[0], x.shape[1]
# replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
# x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
output_x = self.experts[expert_idx].forward(input_x)
return output_x
import pdb; pdb.set_trace()
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx)
# output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
output_x_aver = torch.mean(output_x, dim=1)
# gate_score = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
outputs.append(output_x.unsqueeze(0))
output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
gate_acore = self.gates[expert_idx](output_x_aver)
logits_gate_lst.append(gate_acore)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
@ -259,25 +389,39 @@ class RouteMoELayer(nn.Module):
else:
current_scores_log = current_scores
import pdb;pdb.set_trace()
# importance loss
importance_loss = self._importance_auxiliary_loss(current_scores)
# import pdb; pdb.set_trace()
batch_size = x.shape[0] # bz*num_beam
batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
# beam_scores torch.Size([bz*num_beam])
# expert_route torch.Size([bz*num_beam, layer_n])
current_select_expert = expert_route[:,-1]
# current_select_expert torch.Size([bz*num_beam, 1])
output = list()
for i in range(beam_idx.shape[0]):
b_idx = beam_idx[i]
ex_idx = current_select_expert[i]
ex_out = candidate_output[ex_idx, b_idx, :,:]
output.append(ex_out.unsqueeze(0))
final_output = torch.concat(output, dim=0)
return final_output, beam_scores, expert_route, beam_idx
# import pdb; pdb.set_trace()
if self.layer_judge == 'first':
replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768])
expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
if self.weight_type == 'ffn_prob':
tmp_prob = current_scores[beam_idx] * expert_select_matrix
output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
final_output = torch.sum(output, dim=1)
import pdb; pdb.set_trace()
print("current_scores:\n",current_scores)
return final_output, beam_scores, expert_route, beam_idx, importance_loss
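
On the first routed layer the experts and gate run on bz samples while beam search produces bz*num_beams routes, so the per-expert candidate outputs and the gate scores are expanded along a beam dimension before the selected expert is gathered per beam. A small sketch of just that expansion and gather, with assumed toy sizes:

import torch
import torch.nn.functional as F

num_experts, bz, num_beams, num_tokens, hidden = 2, 3, 2, 32, 768
candidate_raw = torch.randn(num_experts, bz, num_tokens, hidden)   # expert outputs on bz samples
scores = F.softmax(torch.randn(bz, num_experts), dim=-1)           # gate scores on bz samples

# replicate along a beam dimension so each beam has its own copy
candidate_raw = (candidate_raw.unsqueeze(2)
                 .expand(num_experts, bz, num_beams, num_tokens, hidden)
                 .reshape(num_experts, bz * num_beams, num_tokens, hidden))
scores = (scores.unsqueeze(1)
          .expand(bz, num_beams, num_experts)
          .reshape(bz * num_beams, num_experts))

beam_idx = torch.arange(bz * num_beams)                    # identity on the first layer
select = torch.randint(0, num_experts, (bz * num_beams,))  # expert picked by beam search

candidate = candidate_raw.permute(1, 0, 2, 3)[beam_idx]    # [bz*num_beams, num_experts, 32, 768]
w = scores[beam_idx] * F.one_hot(select, num_experts)      # 'ffn_prob' weighting
final_output = (candidate * w.unsqueeze(-1).unsqueeze(-1)).sum(dim=1)
print(final_output.shape)                                  # torch.Size([6, 32, 768])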
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
"""
@ -286,13 +430,12 @@ class RouteMoELayer(nn.Module):
"""
if self.route_method == 'pre-route':
candidate_output, beam_scores, expert_route, _ = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
elif self.route_method == "post-route":
candidate_output, beam_scores, expert_route, beam_idx = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
return candidate_output, beam_scores, expert_route, beam_idx
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
if __name__ == '__main__':
import sys
@ -314,8 +457,8 @@ if __name__ == '__main__':
config.add_cross_attention = True
config.cross_attention_freq = cross_attention_freq
config.query_length = num_query_token
config.moebert_expert_num = 3
config.moebert_num_beams = 3
config.moebert_expert_num = 2
config.moebert_num_beams = 2
config.moebert_route_method = 'gate-sentence'
config.moe_topk = 2
config.use_balance_loss = False
@ -332,40 +475,46 @@ if __name__ == '__main__':
for layer_num in [6, 8, 10]:
layer_judge = moe_layer_judge(layer_num)
ffn = FeedForward(config)
gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
# experts = RouteMoELayer(
# hidden_size=768,
# expert=ffn,
# gate = gate,
# num_experts=config.moebert_expert_num,
# num_beams=config.moebert_num_beams,
# layer_judge = layer_judge,
# route_method = "pre-route"
# route_method = "pre-route",
# weight_type="no_ffn_prob"
# )
# layer_output = experts(x, None, beam_scores, expert_route)
# hidden_states1, beam_scores, expert_route,_ = layer_output
# hidden_states1, beam_scores, expert_route, beam_idx, importance_loss = layer_output
# print(beam_scores)
# print(expert_route)
# print(beam_idx)
# print(importance_loss)
# x = hidden_states1
gate1 = nn.Linear(768, 1, bias=False).float()
experts_post = RouteMoELayer(
hidden_size=768,
expert=ffn,
gate = gate1,
num_experts=config.moebert_expert_num,
num_beams=config.moebert_num_beams,
layer_judge = layer_judge,
route_method = "post-route"
route_method = "post-route",
weight_type="ffn_prob"
)
layer_output = experts_post(x1, None, beam_scores1, expert_route1, False)
hidden_states2, beam_scores1, expert_route1, beam_idx = layer_output
hidden_states2, beam_scores1, expert_route1, beam_idx, importance_loss = layer_output
print(beam_scores1)
print(expert_route1)
print(beam_idx)
print(importance_loss)
x1 = hidden_states2
# gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
# experts_moe = MoELayer(
# hidden_size=config.hidden_size,
# expert=ffn,
@ -382,11 +531,62 @@ if __name__ == '__main__':
# print(select_prob_gate)
# print(gate_load)
# x = hidden_states1
x1 = hidden_states2
# x2 = hidden_states3
print("------------------------------------")
import pdb; pdb.set_trace()
def forward_post_route_backup(self, x, beam_scores, expert_route, use_log=True):
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
output_x = self.experts[expert_idx].forward(input_x)
return output_x
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx)
outputs.append(output_x.unsqueeze(0))
# output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
# gate_score = self.gates[expert_idx](output_x_aver)
output_x_aver = torch.mean(output_x, dim=1)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
if use_log:
current_scores_log = torch.log(current_scores) # log-probabilities can be added directly
else:
current_scores_log = current_scores
# importance loss
importance_loss = self._importance_auxiliary_loss(current_scores)
batch_size = x.shape[0] # bz*num_beam
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
# beam_scores torch.Size([bz*num_beam])
# expert_route torch.Size([bz*num_beam, layer_n])
current_select_expert = expert_route[:,-1]
# current_select_expert torch.Size([bz*num_beam, 1])
output = list()
for i in range(beam_idx.shape[0]):
b_idx = beam_idx[i]
ex_idx = current_select_expert[i]
ex_out = candidate_output[ex_idx, b_idx, :,:]
if self.weight_type == 'ffn_prob':
prob = current_scores[b_idx, ex_idx]
ex_out = ex_out*prob
output.append(ex_out.unsqueeze(0))
final_output = torch.concat(output, dim=0)
# import pdb;pdb.set_trace()
return final_output, beam_scores, expert_route, beam_idx, importance_loss

View File

@ -1,155 +0,0 @@
import torch
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
def forward_expert(input_x, expert_idx):
input_x += torch.randn(32,768)
return input_x
# output_x = self.experts[expert_idx].forward(input_x)
# return output_x
def forward_ffn(x_repeat, expert_select):
"""
x_repeat : [bz*num_beams, 32,768]
expert_select : [bz*num_beams]
"""
outputs = list()
num_beams_bz = x_repeat.shape[0]
for i in range(num_beams_bz):
output_x = forward_expert(x_repeat[i], expert_select[i]) # (32,768)
outputs.append(output_x.unsqueeze(0))
candidate_output = torch.cat(outputs)
return candidate_output # torch.Size([bz*num_beams, 32, 768])
def forward_gate(x, num_expert):
"""
x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts])
"""
# attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
# x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768])
# x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
# logits_gate = gate(x_average) # torch.Size([bz, num_experts])
logits_gate = torch.randn(x.shape[0], num_expert)
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
return prob_gate
def beam_search(layer, current_scores, beam_scores, expert_route, num_beams):
if layer == 0 and beam_scores==None and expert_route==None:
topk_values, gate = torch.topk(current_scores, num_beams, dim=1) # gate: experts assigned to each sample, torch.Size([bz, topk])
beam_scores = topk_values.view(num_beams*batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(num_beams*batch_size).unsqueeze(1) # torch.Size([bz * num_beams])
else:
next_scores_raw = current_scores + beam_scores.unsqueeze(1) # torch.Size([4*3, 5]) # after taking the log, probabilities can be added directly
next_scores_raw1 = next_scores_raw.view(
batch_size, num_beams * num_expert
) # torch.Size([4, 3*5])
next_scores, next_experts = torch.topk(next_scores_raw1, num_beams, dim=1, largest=True, sorted=True)
# next_scores torch.Size([4, 3*num_beams])
# next_tokens torch.Size([4, 3*num_beams])
next_batch_beam = list()
for batch_idx in range(batch_size):
next_sent_beam = list()
print(batch_idx)
for rank, (expert_id, expert_score) in enumerate(
zip(next_experts[batch_idx], next_scores[batch_idx])
):
expert_id = expert_id.item()
beam_id = expert_id // num_expert
ex_id = expert_id % num_expert
effective_beam_id = batch_idx*num_beams + beam_id
# print(expert_id, beam_id, ex_id, effective_beam_id, expert_score)
next_sent_beam.append((expert_score, ex_id, effective_beam_id))
next_batch_beam.extend(next_sent_beam)
# print()
import pdb;pdb.set_trace()
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
pre_route = expert_route[beam_idx,:]
expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
return beam_scores, expert_route
if __name__ == '__main__':
batch_size = 3
num_beams = 2
num_expert = 5
x = torch.randn(batch_size, 32, 768)
beam_scores, expert_route = None, None
for layer in range(0,3):
# import pdb;pdb.set_trace()
current_scores = forward_gate(x, num_expert)
import pdb;pdb.set_trace()
beam_scores, expert_route = beam_search(layer, current_scores, beam_scores, expert_route, num_beams)
current_expert_select = expert_route[:,-1]
if layer == 0:
replicated_tensor = x.unsqueeze(1).expand(batch_size, num_beams, 32, 768)
x = replicated_tensor.contiguous().view(-1, 32, 768) # [12,32,768] [bz*num_beams, 32,768]
else:
x = candidate_output
candidate_output = forward_ffn(x, current_expert_select) # torch.Size([4*3, 5])
x = candidate_output
scores = beam_scores.view(batch_size, num_beams)
topk_values, gate = torch.topk(scores, 1, dim=1)
# gate [batch_size, 1]
# topk_values [batch_size, 1]
selects = [ (bz_idx * num_beams + gate[bz_idx].item()) for bz_idx in range(batch_size)]
final_scores = beam_scores[selects]
final_expert_route = expert_route[selects]
final_output = candidate_output[selects]
# def forward_ffn_post(x_repeat, expert_select):
# """
# x_repeat : [bz*num_beams, 32,768]
# expert_select : [bz*num_beams]
# prob_gate : torch.Size([bz*num_beams, num_experts])
# """
# outputs = list()
# logits_gate_lst = list()
# # attention_mask = torch.ones([batch_size, 32])
# for i in range(num_beams*batch_size):
# output_x = forward_expert(x_repeat[i], expert_select[i]) # (32,768)
# outputs.append(output_x.unsqueeze(0))
# # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
# # gate_acore = self.gates[expert_idx](output_x_aver)
# # gate_score = self.gate(output_x_aver)
# num_expert = 5
# gate_score = torch.randn(1,num_expert)
# logits_gate_lst.append(gate_score)
# candidate_output = torch.cat(outputs) # torch.Size([bz*num_beams, 32, 768])
# logits_gate = torch.cat(logits_gate_lst,dim=0)# torch.Size([bz*num_beams, num_expert])
# prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
# return prob_gate, candidate_output

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'):
def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob'):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
@ -81,54 +81,6 @@ class MoELayer(nn.Module):
return x, balance_loss, gate_load
def _forward_gate_sentence_top1_raw(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which is added to the attention scores directly;
its values are -0.0 or -10000,
so it must be converted to a 1/0 mask before it is processed by the experts
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
gate = torch.argmax(prob_gate, dim=-1) # torch.Size([bz])
order = gate.argsort(0)
num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_sentences.clone()
x = x[order] # reorder according to expert number
x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_sentences.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x.unsqueeze(-1)
return input_x
result = []
for i in range(self.num_experts):
if x[i].size(0) > 0:
result.append(forward_expert(x[i], prob_gate[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
return result, balance_loss, gate_load
def _forward_gate_sentence_post(self, x, attention_mask):
"""
x: query_attention_output; torch.Size([bz, 32, 768])
@ -174,13 +126,17 @@ class MoELayer(nn.Module):
# importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
# output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
# logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])
prob_gate_topk = torch.zeros_like(prob_gate)
prob_gate_topk.scatter_(1, gate, topk_values)
prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) # torch.Size([bz, num_expert])
if self.weight_type == 'average':
# torch.Size([bz, num_expert]); prob_gate_norm is 0 for experts that were not selected
prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True)
elif self.weight_type == 'raw_prob':
prob_gate_normalized = prob_gate_topk
elif self.weight_type == 'softmax_norm':
prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert])
candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
@ -188,6 +144,46 @@ class MoELayer(nn.Module):
return moe_result, (balance_loss+importance_loss), prob_gate_normalized
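
The weight_type options above differ only in how the kept top-k probabilities become mixture weights. A tiny sketch comparing them on a hand-made gate distribution (3 experts, topk=2):

import torch
import torch.nn.functional as F

prob_gate = torch.tensor([[0.5, 0.3, 0.2]])                      # [bz=1, num_experts=3]
topk_values, gate = torch.topk(prob_gate, 2, dim=1)
prob_topk = torch.zeros_like(prob_gate).scatter_(1, gate, topk_values)

average = prob_topk / prob_topk.sum(dim=1, keepdim=True)   # renormalised over the top-k
raw_prob = prob_topk                                       # keep the raw softmax values
softmax_norm = F.softmax(prob_topk, dim=-1)                # softmax again (unselected experts keep a small weight)
print(average, raw_prob, softmax_norm, sep="\n")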
def router(self, x, attention_mask):
# Prepare input x
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = torch.mean(x_masked, dim=1) # torch.Size([bz, 768])
# Forward Gate
# logits_gate: [bz, num_experts]
logits_gate = self.gate(x_average)
# Probabilities for each sample of what expert it should be sent to.
# prob_gate: [bz, num_experts]
prob_gate = F.softmax(logits_gate, dim=-1)
# Get Top-K experts for each sample
# gate: [bz, topk]
# select_prob_gate: [bz, topk]
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)
# Reshape prob_gate & gate
# expert_mask: [batch_size, topk, num_experts]
# expert_gate: [batch_size, topk, num_experts]
# combine_tensor: [batch_size, num_experts]
expert_mask = F.one_hot(gate, self.num_experts)
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = torch.sum(expert_gate, dim=1)
# Calculate Balancing Loss
if self.use_balance_loss:
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# Calculate Importance Loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# import pdb; pdb.set_trace()
return expert_mask, combine_tensor, balance_loss, importance_loss
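
The new router() helper condenses the pre-gating path: mean-pool the query tokens, softmax over experts, keep the top-k, and build a dense [bz, num_experts] combine_tensor that carries the selected probabilities and zeros elsewhere. A standalone sketch with assumed toy sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

bz, num_tokens, hidden, num_experts, topk = 4, 32, 768, 3, 2
gate = nn.Linear(hidden, num_experts, bias=False)

x = torch.randn(bz, num_tokens, hidden)
x_average = x.mean(dim=1)                                   # [bz, 768]
prob_gate = F.softmax(gate(x_average), dim=-1)              # [bz, num_experts]

select_prob, gate_idx = torch.topk(prob_gate, topk, dim=1)  # [bz, topk] each
expert_mask = F.one_hot(gate_idx, num_experts)              # [bz, topk, num_experts]
combine_tensor = (select_prob.unsqueeze(-1) * expert_mask).sum(dim=1)  # [bz, num_experts]
print(combine_tensor)        # zero weight for experts outside each sample's top-k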
def _forward_gate_sentence(self, x, attention_mask):
"""
@ -200,81 +196,37 @@ class MoELayer(nn.Module):
the values of extended_attention_mask are -0.0 or -10000
it should be adjust to 1/0 version to be processed by experts
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: experts assigned to each sample, torch.Size([bz, topk])
# Forward Router
expert_mask, combine_tensor, balance_loss, importance_loss = self.router(x, attention_mask)
# Forward Expert FFN
result = []
for expert_idx in range(self.num_experts):
output_x = self.experts[expert_idx].forward(x)
result.append(output_x.unsqueeze(0))
expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
# weight the selected experts with an L2 norm here
if self.weight_type == 'l2_norm':
normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
elif self.weight_type == 'average':
normalized_tensor = select_prob_gate / select_prob_gate.sum(dim=1, keepdim=True)
# multiply outputs of experts by the routing probability
if self.weight_type == 'raw_prob':
expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
elif self.weight_type == 'no_prob':
combine_index = torch.sum(expert_mask, dim=1)
expert_outputs_combined = expert_output * combine_index.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
gate_load = num_sentences.clone()
outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states])
# load balancing loss
if self.use_balance_loss:
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# import pdb; pdb.set_trace()
# importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# forward experts
def forward_expert(input_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
return input_x
result_lst = list()
for i in range(self.topk):
# for top1, top2, ...: group samples by gate choice, run each group through its expert, then weight by probability and sum
tmp_gate = gate[:,i]
tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1)
order = tmp_gate.argsort(0)
num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0)
x1 = x[order] # reorder according to expert number
x1 = x1.split(num_sentences_t.tolist(), dim=0) # a list of length self.num_experts
result = []
for i in range(self.num_experts):
if x1[i].size(0) > 0:
result.append(forward_expert(x1[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
# result_lst.append(result * tmp_prob) # result * prob
result_lst.append(result) # result * prob # add_1212
moe_result = sum(result_lst)
# import pdb;pdb.set_trace()
return moe_result, (balance_loss+importance_loss), gate
def _forward_sentence_single_expert(self, x, attention_mask):
x_masked = x * attention_mask.unsqueeze(-1)
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
logits_gate = self.gate(x_average)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
x = self.experts[gate.cpu().item()].forward(x)
return x, 0.0, gate_load
return outputs, (balance_loss+importance_loss), combine_tensor
def forward(self, x, attention_mask):
if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x)
elif self.route_method == "gate-sentence":
if x.size(0) == 1:
x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
else:
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
elif self.route_method == "gate-sentence-post":
x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
else:
raise KeyError("Routing method not supported.")
# import pdb; pdb.set_trace()
return x, balance_loss, gate_load

View File

@ -0,0 +1,330 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
self.route_method = route_method
self.topk = topk
self.use_balance_loss = use_balance_loss
self.weight_type = weight_type
if route_method in ["gate-token", "gate-sentence"]:
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
elif route_method in ["gate-sentence-post"]:
gate = nn.Linear(hidden_size, 1, bias=False).float()
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
self.gate = gate
else:
raise KeyError("Routing method not supported.")
def _balancing_loss(self, prob_gate, num_tokens):
# From MOEBERT
# compute the load balancing loss
# prob_gate: [bz, num_expert], probability of assigning each sample to each expert
# equivalent to _gshard_auxiliary_loss in VMOE
P = prob_gate.mean(0) # torch.Size([num_expert]), mean probability assigned to each expert
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True) # fraction of samples assigned to each expert
balance_loss = self.num_experts * torch.sum(P * f)
return balance_loss
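# Added note on _balancing_loss above: with num_experts=2 and balanced routing
# (P = f = [0.5, 0.5]) the loss is 2 * (0.25 + 0.25) = 1.0, while a skewed
# assignment such as P = f = [0.9, 0.1] gives 2 * (0.81 + 0.01) = 1.64,
# so imbalanced routing is penalized more heavily.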
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
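# Added note on _importance_auxiliary_loss above: it is the squared coefficient
# of variation (std/mean)^2 of the total gate probability received by each
# expert over the batch; it is 0 when every expert receives the same probability
# mass and grows as routing concentrates on a few experts.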
def _forward_gate_token(self, x):
bsz, seq_len, dim = x.size()
x = x.view(-1, dim)
logits_gate = self.gate(x)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
order = gate.argsort(0)
num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_tokens.clone()
x = x[order] # reorder according to expert number
x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x
return input_x
x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
x = torch.vstack(x)
x = x[order.argsort(0)] # restore original order
x = x.view(bsz, seq_len, dim)
return x, balance_loss, gate_load
def _forward_gate_sentence_top1_raw(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which is added to the attention scores directly;
its values are -0.0 or -10000,
so it should be converted to a 1/0 mask before being processed by the experts
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
gate = torch.argmax(prob_gate, dim=-1) # torch.Size([bz])
order = gate.argsort(0)
num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_sentences.clone()
x = x[order] # reorder according to expert number
x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_sentences.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x.unsqueeze(-1)
return input_x
result = []
for i in range(self.num_experts):
if x[i].size(0) > 0:
result.append(forward_expert(x[i], prob_gate[i], i))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
return result, balance_loss, gate_load
def _forward_gate_sentence_post(self, x, attention_mask):
"""
x: query_attention_output; torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
bz = 4
x = torch.randn(bz,32,768)
attention_mask = torch.ones([bz, 32])
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
# input_x += torch.randn(4,32,768)
# return input_x
output_x = self.experts[expert_idx].forward(input_x)
return output_x
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx)
outputs.append(output_x.unsqueeze(0))
output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
# gate_acore = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
topk_values, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: the expert(s) assigned to each sample, torch.Size([bz, topk])
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
gate_load = num_sentences.clone()
# load balancing loss
if self.use_balance_loss:
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768])
# output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768])
# logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1])
prob_gate_topk = torch.zeros_like(prob_gate)
prob_gate_topk.scatter_(1, gate, topk_values)
if self.weight_type == 'average':
# torch.Size([bz, num_expert]); unselected experts get a normalized probability of 0
prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True)
elif self.weight_type == 'raw_prob':
prob_gate_normalized = prob_gate_topk
elif self.weight_type == 'softmax_norm':
prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert])
candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768])
moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
# import pdb;pdb.set_trace()
return moe_result, (balance_loss+importance_loss), prob_gate_normalized
# def _forward_gate_sentence(self, x, attention_mask):
# attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
# x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
# x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
# logits_gate = self.gate(x_average)
# prob_gate = F.softmax(logits_gate, dim=-1)
# gate = torch.argmax(prob_gate, dim=-1)
# order = gate.argsort(0)
# num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
# gate_load = num_sentences.clone()
# x = x[order] # reorder according to expert number
# x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts
# # compute the load balancing loss
# P = prob_gate.mean(0)
# temp = num_sentences.float()
# f = temp / temp.sum(0, keepdim=True)
# balance_loss = self.num_experts * torch.sum(P * f)
# prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
# prob_gate = prob_gate[order]
# prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)
# def forward_expert(input_x, prob_x, expert_idx):
# input_x = self.experts[expert_idx].forward(input_x)
# input_x = input_x * prob_x.unsqueeze(-1)
# return input_x
# result = []
# for i in range(self.num_experts):
# if x[i].size(0) > 0:
# result.append(forward_expert(x[i], prob_gate[i], i))
# result = torch.vstack(result)
# result = result[order.argsort(0)] # restore original order
# return result, balance_loss, gate_load
def _forward_gate_sentence(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which is added to the attention scores directly;
its values are -0.0 or -10000,
so it should be converted to a 1/0 mask before being processed by the experts
"""
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768])
logits_gate = self.gate(x_average) # torch.Size([bz, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: the expert(s) assigned to each sample, torch.Size([bz, topk])
# weight the selected top-k probabilities (L2 norm variant)
if self.weight_type == 'l2_norm':
# note: strictly speaking neither dim=0 nor dim=1 is the right normalization axis
normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=1) # L2 Normalization torch.Size([bz, topk])
elif self.weight_type == 'l2_norm_0':
normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
elif self.weight_type == 'average':
normalized_tensor = select_prob_gate / select_prob_gate.sum(dim=1, keepdim=True)
elif self.weight_type == 'raw_prob':
normalized_tensor = select_prob_gate
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
gate_load = num_sentences.clone()
# load balancing loss
if self.use_balance_loss:
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# importance loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# forward experts
def forward_expert(input_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
return input_x
result_lst = list()
for i in range(self.topk):
# for top-1, top-2, ... in turn: group samples by the gate selection, run each group through its expert, then weight by the routing probability and sum
tmp_gate = gate[:,i]
tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1)
order = tmp_gate.argsort(0)
num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0)
x1 = x[order] # reorder according to expert number
x1 = x1.split(num_sentences_t.tolist(), dim=0) # a list of length self.num_experts
result = []
for j in range(self.num_experts):
if x1[j].size(0) > 0:
result.append(forward_expert(x1[j], j))
result = torch.vstack(result)
result = result[order.argsort(0)] # restore original order
result_lst.append(result * tmp_prob) # result * prob
# result_lst.append(result) # result * prob # add_1212
moe_result = sum(result_lst)
return moe_result, (balance_loss+importance_loss), gate
def _forward_sentence_single_expert(self, x, attention_mask):
x_masked = x * attention_mask.unsqueeze(-1)
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
logits_gate = self.gate(x_average)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
x = self.experts[gate.cpu().item()].forward(x)
return x, 0.0, gate_load
def forward(self, x, attention_mask):
if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x)
elif self.route_method == "gate-sentence":
if x.size(0) == 1:
x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
else:
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
elif self.route_method == "gate-sentence-post":
x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
else:
raise KeyError("Routing method not supported.")
# import pdb; pdb.set_trace()
return x, balance_loss, gate_load
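# Minimal usage sketch (illustrative, assuming the MoELayer above is importable):
# a plain nn.Linear stands in for the Q-Former FeedForward expert.
import torch
import torch.nn as nn
expert = nn.Linear(768, 768)                     # toy stand-in for FeedForward(config)
moe = MoELayer(hidden_size=768, expert=expert, num_experts=3,
               route_method="gate-sentence", topk=2,
               use_balance_loss=False, weight_type="average")
x = torch.randn(4, 32, 768)                      # query_attention_output
attn_mask = torch.ones(4, 32)
out, aux_loss, gate = moe(x, attn_mask)
print(out.shape)                                 # torch.Size([4, 32, 768])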

View File

@ -92,7 +92,6 @@ class PrePromptMoE(PromptMoEBase):
self.topk = topk
if route_method in ["gate-token", "gate-single-token", "gate-sentence"]:
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
print(self.gate)
else:
raise KeyError("Routing method not supported.")

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
class RouteMoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route"):
def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
@ -13,6 +13,7 @@ class RouteMoELayer(nn.Module):
self.num_beams = num_beams
self.hidden_size = hidden_size
self.layer_judge = layer_judge
self.weight_type = weight_type
self.route_method = route_method
if self.route_method == "pre-route":
@ -22,6 +23,17 @@ class RouteMoELayer(nn.Module):
self.gate = gate
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
def forward_gate(self, x):
"""
x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
@ -29,7 +41,8 @@ class RouteMoELayer(nn.Module):
"""
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768])
x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
# x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
x_average = torch.mean(x_masked, dim=1) # torch.Size([bz*num_beams, 768])
logits_gate = self.gate(x_average) # torch.Size([bz*num_beams, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
return prob_gate
@ -42,7 +55,7 @@ class RouteMoELayer(nn.Module):
topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: the expert(s) assigned to each sample, torch.Size([bz, topk])
beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
beam_idx = None
beam_idx = torch.tensor(range(self.num_beams * batch_size))
else:
if self.layer_judge=='first' and self.route_method == 'post-route':
batch_size = batch_size
@ -89,54 +102,63 @@ class RouteMoELayer(nn.Module):
return beam_scores, expert_route, beam_idx
def forward_expert_ffn(self, x, expert_select, beam_scores):
def forward_expert_ffn(self, x, expert_select, current_scores):
"""
x_repeat : [bz*num_beams, 32,768]
expert_select : [bz*num_beams]
current_scores : [bz*num_beams, num_experts] / [bz, num_experts]
"""
# add_1212 l2_normalization
# normalized_tensor = torch.nn.functional.normalize(beam_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
# add_1228 l2_normalization
# normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
# tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)
# import pdb;pdb.set_trace()
outputs = list()
for i in range(x.shape[0]):
output_x = self.experts[expert_select[i]].forward(x[i])
outputs.append(output_x.unsqueeze(0))
candidate_output = torch.cat(outputs)
# candidate_output = candidate_output * tmp_prob
return candidate_output # torch.Size([bz*num_beams, 32, 768])
for i in range(self.num_experts):
output_x = self.experts[i].forward(x)
outputs.append(output_x.unsqueeze(1))
candidate_output = torch.cat(outputs, dim=1)
expert_select_matrix = F.one_hot(expert_select, self.num_experts)
if self.weight_type == 'ffn_prob':
tmp_prob = current_scores * expert_select_matrix
candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
output = torch.sum(candidate_output, dim=1)
# import pdb;pdb.set_trace()
return output # torch.Size([bz*num_beams, 32, 768])
def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
current_scores = self.forward_gate(x) # [bz*num_beams, 5]
current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams]
importance_loss = self._importance_auxiliary_loss(current_scores)
if use_log:
current_scores_log = torch.log(current_scores) # take the log so scores can be summed directly across layers
else:
current_scores_log = current_scores
# import pdb;pdb.set_trace()
batch_size, num_tokens = x.shape[0], x.shape[1]
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
current_expert_select = expert_route[:,-1]
if self.layer_judge=='first': # expand first dim to batch_size * num_beams
replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
candidate_output = self.forward_expert_ffn(x, current_expert_select, beam_scores) # [bz*num_beams, 32,768]
return candidate_output, beam_scores, expert_route, beam_idx
input_x = x[beam_idx]
candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768]
# import pdb;pdb.set_trace()
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
output_x = self.experts[expert_idx].forward(input_x)
return output_x
@ -145,12 +167,14 @@ class RouteMoELayer(nn.Module):
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx)
outputs.append(output_x.unsqueeze(0))
output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
# output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
output_x_aver = torch.mean(output_x, dim=1)
# gate_score = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
outputs.append(output_x.unsqueeze(0))
candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
@ -159,24 +183,33 @@ class RouteMoELayer(nn.Module):
else:
current_scores_log = current_scores
batch_size = x.shape[0] # bz*num_beam
# importance loss
importance_loss = self._importance_auxiliary_loss(current_scores)
batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
# beam_scores torch.Size([bz*num_beam])
# expert_route torch.Size([bz*num_beam, layer_n])
current_select_expert = expert_route[:,-1]
# current_select_expert torch.Size([bz*num_beam, 1])
output = list()
for i in range(beam_idx.shape[0]):
b_idx = beam_idx[i]
ex_idx = current_select_expert[i]
ex_out = candidate_output[ex_idx, b_idx, :,:]
output.append(ex_out.unsqueeze(0))
final_output = torch.concat(output, dim=0)
return final_output, beam_scores, expert_route, beam_idx
if self.layer_judge == 'first':
replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768])
expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
if self.weight_type == 'ffn_prob':
tmp_prob = current_scores[beam_idx] * expert_select_matrix
output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
final_output = torch.sum(output, dim=1)
return final_output, beam_scores, expert_route, beam_idx, importance_loss
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
"""
if first_layer: x [bz, 32, 768]
@ -184,11 +217,11 @@ class RouteMoELayer(nn.Module):
"""
if self.route_method == 'pre-route':
candidate_output, beam_scores, expert_route, beam_idx = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
elif self.route_method == "post-route":
candidate_output, beam_scores, expert_route, beam_idx = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
return candidate_output, beam_scores, expert_route, beam_idx
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
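# Hedged illustration of the 'ffn_prob' weighting introduced above: each row keeps
# only the routing score of its selected expert and scales that expert's output.
# All tensors here are toy stand-ins.
import torch
import torch.nn.functional as F
num_experts = 3
current_scores = torch.softmax(torch.randn(4, num_experts), dim=-1)  # [bz*num_beams, num_experts]
expert_select = torch.tensor([2, 0, 1, 2])                           # expert chosen per row
candidate_output = torch.randn(4, num_experts, 32, 768)              # per-expert outputs
expert_select_matrix = F.one_hot(expert_select, num_experts)         # [4, num_experts]
tmp_prob = current_scores * expert_select_matrix                     # zero out non-selected experts
output = (candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)).sum(dim=1)
print(output.shape)                                                  # torch.Size([4, 32, 768])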

View File

@ -0,0 +1,294 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob, topk(softmax)'):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
self.route_method = route_method
self.topk = topk
self.use_balance_loss = use_balance_loss
self.weight_type = weight_type
if route_method in ["gate-token", "gate-sentence"]:
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
elif route_method in ["gate-sentence-post"]:
gate = nn.Linear(hidden_size, 1, bias=False).float()
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
self.gate = gate
else:
raise KeyError("Routing method not supported.")
def _balancing_loss(self, prob_gate, num_tokens):
# From MOEBERT
# compute the load balancing loss
# prob_gate: [bz, num_expert], probability of assigning each sample to each expert
# equivalent to _gshard_auxiliary_loss in VMOE
P = prob_gate.mean(0) # torch.Size([num_expert]), mean probability assigned to each expert
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True) # fraction of samples assigned to each expert
balance_loss = self.num_experts * torch.sum(P * f)
return balance_loss
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
def _forward_gate_token(self, x):
bsz, seq_len, dim = x.size()
x = x.view(-1, dim)
logits_gate = self.gate(x)
prob_gate = F.softmax(logits_gate, dim=-1)
gate = torch.argmax(prob_gate, dim=-1)
order = gate.argsort(0)
num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
gate_load = num_tokens.clone()
x = x[order] # reorder according to expert number
x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts
# compute the load balancing loss
P = prob_gate.mean(0)
temp = num_tokens.float()
f = temp / temp.sum(0, keepdim=True)
balance_loss = self.num_experts * torch.sum(P * f)
prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
prob_gate = prob_gate[order]
prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)
def forward_expert(input_x, prob_x, expert_idx):
input_x = self.experts[expert_idx].forward(input_x)
input_x = input_x * prob_x
return input_x
x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
x = torch.vstack(x)
x = x[order.argsort(0)] # restore original order
x = x.view(bsz, seq_len, dim)
return x, balance_loss, gate_load
def _forward_gate_sentence_post(self, x, attention_mask):
"""
x: query_attention_output; torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
bz = 4
x = torch.randn(bz,32,768)
attention_mask = torch.ones([bz, 32])
"""
# Prepare Input x
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
# FeedForward(x) & Forward Gate
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = self.experts[expert_idx].forward(x_masked)
outputs.append(output_x.unsqueeze(0))
output_x_aver = torch.mean(output_x, dim=1)
# gate_acore = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert])
# Probabilities for each sample of what expert it should be sent to.
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts])
if 'softmax(topk)' in self.weight_type:
prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1)
select_prob_gate = F.softmax(prob_gate1, dim=-1)
else:
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate: the expert(s) assigned to each sample, torch.Size([bz, topk])
# Calculate Balancing Loss
if self.use_balance_loss:
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# Calculate Importance Loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# Reshape prob_gate & gate
# expert_mask: [batch_size, topk, num_experts]
# expert_gate: [batch_size, topk, num_experts]
# combine_tensor: [batch_size, num_experts]
expert_mask = F.one_hot(gate, self.num_experts)
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = torch.sum(expert_gate, dim=1)
# combine_tensor = torch.zeros_like(prob_gate)
# combine_tensor.scatter_(1, gate, select_prob_gate) # equivalent operation, but may not be differentiable
candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768])
results = candidate_output_ad * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([bz, num_expert, 32, 768])
outputs = torch.sum(results, dim=1) # torch.Size([bz, 32, 768])
# import pdb;pdb.set_trace()
return outputs, (balance_loss+importance_loss), combine_tensor
def pre_router(self, x, attention_mask):
# Prepare input x
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
x_average = torch.mean(x_masked, dim=1) # torch.Size([bz, 768])
# Forward Gate
# logits_gate: [bz, num_experts]
logits_gate = self.gate(x_average)
# Probabilities for each sample of what expert it should be sent to.
# prob_gate: [bz, num_experts]
prob_gate = F.softmax(logits_gate, dim=-1)
if 'softmax(topk)' in self.weight_type:
prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1)
select_prob_gate = F.softmax(prob_gate1, dim=-1)
else:
# topk(softmax)
# Get Top-K experts for each sample
# gate: [bz, topk]
# select_prob_gate: [bz, topk]
select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)
# Reshape prob_gate & gate
# expert_mask: [batch_size, topk, num_experts]
# expert_gate: [batch_size, topk, num_experts]
# combine_tensor: [batch_size, num_experts]
expert_mask = F.one_hot(gate, self.num_experts)
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = torch.sum(expert_gate, dim=1)
# Calculate Balancing Loss
if self.use_balance_loss:
num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # number of samples assigned to each expert, torch.Size([num_expert])
balance_loss = self._balancing_loss(prob_gate, num_sentences)
else:
balance_loss = 0.0
# Calculate Importance Loss
importance_loss = self._importance_auxiliary_loss(prob_gate)
# import pdb; pdb.set_trace()
return expert_mask, combine_tensor, balance_loss, importance_loss
def _forward_gate_sentence(self, x, attention_mask):
"""
x: query_attention_output , torch.Size([bz, 32, 768])
attention_mask: torch.ones([bz, 32])
### Notice:
the raw version of expert_attention_mask is the extended_attention_mask,
which is added to the attention scores directly;
its values are -0.0 or -10000,
so it should be converted to a 1/0 mask before being processed by the experts
"""
# Forward Router
expert_mask, combine_tensor, balance_loss, importance_loss = self.pre_router(x, attention_mask)
# Forward Expert FFN
result = []
for expert_idx in range(self.num_experts):
output_x = self.experts[expert_idx].forward(x)
result.append(output_x.unsqueeze(0))
expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
# multiply outputs of experts by the routing probability
expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states])
outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states])
# import pdb; pdb.set_trace()
return outputs, (balance_loss+importance_loss), combine_tensor
def forward(self, x, attention_mask):
if self.route_method == "gate-token":
x, balance_loss, gate_load = self._forward_gate_token(x)
elif self.route_method == "gate-sentence":
x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
elif self.route_method == "gate-sentence-post":
x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
else:
raise KeyError("Routing method not supported.")
# import pdb; pdb.set_trace()
return x, balance_loss, gate_load
if __name__ == '__main__':
import sys
sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
from minigpt4.models.QformerRouteMoE import BertConfig
from minigpt4.models.QformerRouteMoE import FeedForward
from minigpt4.models.moe.utils import (
moe_layer_judge,
)
vision_width = 1408
cross_attention_freq = 2
num_query_token = 32
# init_QformerMoE
config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
config.encoder_width = vision_width
# insert cross-attention layer every other block
config.add_cross_attention = True
config.cross_attention_freq = cross_attention_freq
config.query_length = num_query_token
config.moebert_expert_num = 3
config.moebert_num_beams = 2
config.moebert_route_method = 'gate-sentence-post'
config.moe_topk = 1
config.use_balance_loss = False
# config.moe_weight_type = 'raw_prob, softmax(topk)'
config.moe_weight_type = 'raw_prob, topk(softmax)'
batch_size = 4
x2 = torch.randn(batch_size, 32, 768)
beam_scores, expert_route = None, None
for layer_num in [6, 8, 10]:
layer_judge = moe_layer_judge(layer_num)
ffn = FeedForward(config)
gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()
experts_moe = MoELayer(
hidden_size=config.hidden_size,
expert=ffn,
num_experts=config.moebert_expert_num,
route_method=config.moebert_route_method,
topk=config.moe_topk,
use_balance_loss=config.use_balance_loss,
weight_type=config.moe_weight_type,
)
attn_mask = torch.ones([batch_size, 32])
layer_output = experts_moe(x2, attn_mask)
hidden_states3, aux_loss, combine_tensor = layer_output
print(combine_tensor)
print(aux_loss)
x2 = hidden_states3
print("------------------------------------")
import pdb; pdb.set_trace()
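# Small worked example (illustrative only) of how combine_tensor is built in
# pre_router / _forward_gate_sentence_post above: the top-k probabilities are
# scattered back to a dense [bz, num_experts] weighting.
import torch
import torch.nn.functional as F
prob_gate = torch.tensor([[0.1, 0.6, 0.3],
                          [0.5, 0.2, 0.3]])      # [bz=2, num_experts=3]
select_prob_gate, gate = torch.topk(prob_gate, k=2, dim=1)
expert_mask = F.one_hot(gate, 3)                 # [bz, topk, num_experts]
expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask
combine_tensor = expert_gate.sum(dim=1)          # [bz, num_experts]
print(combine_tensor)                            # [[0.0, 0.6, 0.3], [0.5, 0.0, 0.3]]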

View File

@ -19,15 +19,33 @@ def use_experts(layer_idx):
else:
return False
def use_experts_route(layer_idx):
# if layer_idx % 2 == 0:
# use moe_ffn after cross_attns
# if int(layer_idx) in [0,2,4,6,8,10]:
if int(layer_idx) in [6,7,8,9,10,11]:
return True
else:
return False
def moe_layer_judge(layer_idx):
if layer_idx == 6:
return 'first'
elif layer_idx == 8:
elif layer_idx in [7,8,9,10]:
return 'mid'
elif layer_idx == 10:
elif layer_idx == 11:
return 'last'
else:
return None
# if layer_idx == 0:
# return 'first'
# elif layer_idx in [2,4,6,8]:
# return 'mid'
# elif layer_idx == 10:
# return 'last'
# else:
# return None
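# Added sanity note on the updated thresholds above: use_experts_route is now True
# only for layers 6-11, and moe_layer_judge maps 6 -> 'first', 7/8/9/10 -> 'mid',
# 11 -> 'last', everything else -> None, e.g.
#   [moe_layer_judge(i) for i in range(12)]
#   -> [None, None, None, None, None, None, 'first', 'mid', 'mid', 'mid', 'mid', 'last']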
def process_ffn(model):
if model.config.model_type == "bert":

View File

@ -10,7 +10,6 @@ model:
load_finetuned: False
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/20231201184/checkpoint_best.pth"
finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
@ -38,7 +37,7 @@ model:
# moe
use_moeqformer: True
moebert_expert_num: 5
moebert_expert_num: 3
moebert_route_method: "gate-sentence-post"
moebert_load_balance: 0
moe_topk: 1
@ -110,6 +109,7 @@ run:
max_epoch: 1
num_workers: 4
warmup_steps: 600
iters_per_epoch: 1000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/"

View File

@ -10,7 +10,7 @@ model:
load_finetuned: True
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_linear_gate_3ex_3beam_1loss_top3layer_log_textinqf_epo3_1216/20231216155/checkpoint_best.pth"
finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/20240112212/checkpoint_best.pth"
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
@ -38,10 +38,12 @@ model:
# moe
use_moeqformer: True
use_route_moe: True
moebert_expert_num: 3
moebert_num_beams: 3
moebert_route_method: "post-route"
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209_eval_latest1/"
moebert_load_balance: 0
moebert_expert_num: 2
moebert_num_beams: 2
moe_weight_type: 'ffn_prob'
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
datasets:
gqa:
@ -81,19 +83,20 @@ run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 2e-5
init_lr: 5e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 5
max_epoch: 10
num_workers: 4
warmup_steps: 600
iters_per_epoch: 3000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
amp: True
resume_ckpt_path: null

View File

@ -38,10 +38,12 @@ model:
# moe
use_moeqformer: True
use_route_moe: True
moebert_route_method: "post-route"
moebert_load_balance: 0
moebert_expert_num: 3
moebert_num_beams: 3
moebert_route_method: "post-route"
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
moe_weight_type: 'ffn_prob'
use_balance_loss: False
datasets:
gqa: # train: 943000, 12578, 12578)
@ -97,19 +99,20 @@ run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 2e-5
init_lr: 5e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 5
max_epoch: 8
num_workers: 4
warmup_steps: 600
iters_per_epoch: 5000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_1gate_3ex_3beam_1loss_5e5lr_top6layer_textinqf_epo8_0117/"
amp: True
resume_ckpt_path: null

View File

@ -0,0 +1,129 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
# 0107test
model:
arch: blip2_vicuna_instruct
model_type: vicuna7b_pretrain
load_pretrained: True
load_finetuned: False
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
# Q-Former
num_query_token: 32
qformer_text_input: True
# vicuna
llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
prompt: ""
max_txt_len: 256
max_output_txt_len: 256
# freeze
freeze_vit: True
freeze_llm: True
freeze_qformer: False
freeze_t5_proj: False
# moe
use_moeqformer: True
use_route_moe: True
moebert_route_method: "post-route"
moebert_load_balance: 0
moebert_expert_num: 2
moebert_num_beams: 2
moe_weight_type: 'ffn_prob'
use_balance_loss: False
# gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
datasets:
# gqa: # train: 943000, 12578, 12578)
# type: balanced_sft_raw
# batch_size: 1
# vis_processor:
# train:
# name: "blip2_image_train"
# image_size: 224
# eval:
# name: "blip2_image_eval"
# image_size: 224
# text_processor:
# train:
# name: "blip_caption"
# eval:
# name: "blip_caption"
# sample_ratio: 10
ok_vqa: # train, valid (9009, 5046)
batch_size: 1
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 1
# coco_vqa: # 658104
# batch_size: 1
# vis_processor:
# train:
# name: "blip2_image_train"
# image_size: 224
# eval:
# name: "blip2_image_eval"
# image_size: 224
# text_processor:
# train:
# name: "blip_caption"
# eval:
# name: "blip_caption"
# sample_ratio: 9
run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
init_lr: 2e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 5
num_workers: 4
warmup_steps: 600
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
valid_splits: ["val"]
# test_splits: ["val"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True

View File

@ -10,7 +10,7 @@ model:
load_finetuned: True
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_1048k_raw_QformerMoE_Route_Post_NoNorm_5ex_2beam_1loss_top3layer_textinqf_epo6_1215/20231216161/checkpoint_best.pth"
finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/20240111145/checkpoint_best.pth"
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
@ -39,8 +39,11 @@ model:
use_moeqformer: True
use_route_moe: True
moebert_route_method: "post-route"
moebert_expert_num: 5
moebert_load_balance: 0
moebert_expert_num: 2
moebert_num_beams: 2
moe_weight_type: 'ffn_prob'
use_balance_loss: False
datasets:
ok_vqa: # train, valid (9009, 5046)
@ -78,7 +81,7 @@ evaluation_datasets:
run:
task: instruction_tuning
name: vqa_benchmark_evaluation
save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/eval/benchmarks/mix_1048k_raw_QformerMoE_Route_Post_NoNorm_5ex_2beam_1loss_top3layer_textinqf_epo6_1215/"
save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/benchmarks/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/"
seed: 42

View File

@ -0,0 +1,131 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
model:
arch: blip2_vicuna_instruct
model_type: vicuna7b_pretrain
load_pretrained: True
load_finetuned: False
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
# Q-Former
num_query_token: 32
qformer_text_input: True
# vicuna7b
llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
prompt: ""
max_txt_len: 256
max_output_txt_len: 256
# freeze
freeze_vit: True
freeze_llm: True
freeze_qformer: False
freeze_t5_proj: False
# moe
use_moeqformer: True
use_route_moe: True
moebert_route_method: "post-route"
moebert_load_balance: 0.05
moebert_expert_num: 3
moebert_num_beams: 3
moe_weight_type: 'ffn_prob'
use_balance_loss: False
datasets:
gqa: # train: 943000, 12578, 12578)
type: balanced_sft_raw
# batch_size: 16
batch_size: 32
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 50
ok_vqa: # train, valid (9009, 5046)
# batch_size: 16
batch_size: 32
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 8
coco_vqa: # 658104
# batch_size: 16
batch_size: 32
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 15
run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
# init_lr: 2e-5
init_lr: 5e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 8
num_workers: 4
warmup_steps: 600
iters_per_epoch: 5000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_3ex_3beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
valid_splits: ["val"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True

View File

@ -37,19 +37,19 @@ model:
# moe
use_moeqformer: True
moebert_expert_num: 5
moebert_expert_num: 3
moebert_route_method: "gate-sentence"
moebert_load_balance: 0
moe_topk: 1
use_balance_loss: False
moe_weight_type: 'l2_norm'
gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/"
moe_weight_type: 'raw_prob'
# gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/"
datasets:
gqa: # train: 94254
type: balanced_sft_raw_part
batch_size: 32
batch_size: 1
vis_processor:
train:
name: "blip2_image_train"
@ -65,7 +65,7 @@ datasets:
sample_ratio: 50
ok_vqa: # train, valid (9009, 5046
batch_size: 32
batch_size: 1
vis_processor:
train:
name: "blip2_image_train"
@ -80,22 +80,22 @@ datasets:
name: "blip_caption"
sample_ratio: 8
coco_vqa: # 214352 vqa_val
type: vqa_v2_part
batch_size: 32
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 15
# coco_vqa: # 214352 vqa_val
# type: vqa_v2_part
# batch_size: 1
# vis_processor:
# train:
# name: "blip2_image_train"
# image_size: 224
# eval:
# name: "blip2_image_eval"
# image_size: 224
# text_processor:
# train:
# name: "blip_caption"
# eval:
# name: "blip_caption"
# sample_ratio: 15
run:
task: instruction_tuning
@ -108,12 +108,13 @@ run:
save_freq: 1500
weight_decay: 0.05
max_epoch: 5
max_epoch: 1
num_workers: 4
warmup_steps: 600
iters_per_epoch: 1000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/"
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1220_test/"
amp: True
resume_ckpt_path: null

View File

@ -0,0 +1,125 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
model:
arch: blip2_vicuna_instruct
model_type: vicuna7b_pretrain
load_pretrained: True
load_finetuned: False
vit_model: eva_clip_g
pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# finetuned: ""
q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth"
# vit encoder
image_size: 224
drop_path_rate: 0
use_grad_checkpoint: False
vit_precision: "fp16"
# Q-Former
num_query_token: 32
qformer_text_input: True
# vicuna7b
llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1"
prompt: ""
max_txt_len: 256
max_output_txt_len: 256
# freeze
freeze_vit: True
freeze_llm: True
freeze_qformer: False
freeze_t5_proj: False
# moe
use_moeqformer: False
moebert_expert_num: 1
moebert_route_method: "gate-sentence"
moebert_load_balance: 0.05
moe_topk: 1
datasets:
gqa: # train: 943000, 12578, 12578)
type: balanced_sft_raw
batch_size: 16
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 50
ok_vqa: # train, valid (9009, 5046)
batch_size: 16
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 8
coco_vqa: # 658104
batch_size: 16
vis_processor:
train:
name: "blip2_image_train"
image_size: 224
eval:
name: "blip2_image_eval"
image_size: 224
text_processor:
train:
name: "blip_caption"
eval:
name: "blip_caption"
sample_ratio: 15
run:
task: instruction_tuning
# optimizer
lr_sched: "linear_warmup_cosine_lr"
# init_lr: 2e-5
init_lr: 5e-5
min_lr: 1e-6
warmup_lr: 1e-6
log_freq: 5
save_freq: 1500
weight_decay: 0.05
max_epoch: 8
num_workers: 4
warmup_steps: 600
iters_per_epoch: 5000
seed: 42
output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_1610k_raw_QformerMoE_train_qf_train_qt_1ex_top1_textinqf_epo8_lr5e5_seed42_0112/"
amp: True
resume_ckpt_path: null
evaluate: False
train_splits: ["train"]
valid_splits: ["val"]
device: "cuda"
world_size: 1
dist_url: "env://"
distributed: True

View File

@ -110,6 +110,7 @@ class RunnerBase:
else:
p_wd.append(p)
num_parameters += p.data.nelement()
# import pdb; pdb.set_trace() # 0107test
logging.info("number of trainable parameters: %d" % num_parameters)
optim_params = [
{

View File

@ -238,13 +238,17 @@ class BaseTask:
with torch.cuda.amp.autocast(enabled=use_amp):
loss = self.train_step(model=model, samples=samples)
# after_train_step()
if use_amp:
# torch.autograd.set_detect_anomaly(True)
# enable anomaly detection during backward to locate where NaN/Inf values originate
# with torch.autograd.detect_anomaly():
scaler.scale(loss).backward()
else:
loss.backward()
# import pdb; pdb.set_trace() # 0107test
# update gradients every accum_grad_iters iterations
if (i + 1) % accum_grad_iters == 0:
if use_amp:
@ -252,6 +256,9 @@ class BaseTask:
scaler.update()
else:
optimizer.step()
# import pdb; pdb.set_trace()# 0107test
optimizer.zero_grad()
# if self.cfg.wandb_log:
# if self.cfg.run_cfg.wandb_log:

View File

@ -44,4 +44,6 @@ wheel
visualizer
tensorboard
kmeans_pytorch
visual_genome
visual_genome
gpustat
torchviz

36
setup.py Normal file
View File

@ -0,0 +1,36 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
from setuptools import setup, find_namespace_packages
import platform
DEPENDENCY_LINKS = []
if platform.system() == "Windows":
DEPENDENCY_LINKS.append("https://download.pytorch.org/whl/torch_stable.html")
def fetch_requirements(filename):
with open(filename) as f:
return [ln.strip() for ln in f.read().split("\n")]
setup(
name="PromptMoE",
version="1.0.1",
author="Hanzi Wang",
description="PromptMoE & QformerMoE Based on LAVIS",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="Vision-Language, Multimodal, Image Captioning, Generative AI, Deep Learning, Library, PyTorch",
license="3-Clause BSD",
packages=find_namespace_packages(include="lavis.*"),
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.7.0",
include_package_data=True,
dependency_links=DEPENDENCY_LINKS,
zip_safe=False,
)

5570
test.pdf/backward_graph Normal file

File diff suppressed because it is too large Load Diff

BIN
test.pdf/backward_graph.pdf Normal file

Binary file not shown.

360
test.txt Normal file
View File

@ -0,0 +1,360 @@
tmp_name = [name for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)]
tmp = [p for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)]
tensor([[-1.4032e-02, 3.7242e-03, 8.4997e-03, -3.4016e-03, -6.4855e-03,
4.3595e-02, 3.4423e-02, -8.6274e-03, -1.9702e-02, 9.1813e-03,
1.1643e-02, 2.3939e-02, -2.0908e-02, 3.4555e-03, 9.1636e-03,
1.5413e-02, 2.4148e-02, -1.0880e-03, 1.1193e-02, -1.3591e-02,
9.3484e-03, 1.5999e-02, -9.6086e-04, 3.8322e-02, -8.0687e-03,
-1.4056e-02, 3.9486e-02, 3.5167e-02, -9.3226e-03, -1.0493e-02,
-2.5795e-02, -9.7541e-03, 4.4437e-03, 7.7226e-03, 7.5210e-03,
-1.3526e-02, -5.0316e-03, -1.1149e-02, 6.0583e-03, 2.0564e-02,
-6.4477e-03, 1.4170e-02, -3.7847e-02, 1.1780e-02, 1.3321e-02,
-8.2501e-03, -1.0298e-02, 1.4805e-02, -1.2432e-02, -1.9159e-02,
-5.7095e-04, -3.8618e-02, -2.4230e-02, -1.4991e-03, -1.4114e-02,
-1.5365e-02, 1.5640e-02, -4.8623e-02, -2.9991e-02, 1.2796e-02,
-4.9917e-03, 2.3846e-03, 7.7368e-03, 1.2913e-02, 1.5300e-02,
8.5125e-03, 1.1582e-02, 8.1161e-03, 4.2259e-03, 7.6109e-03,
          ... (remaining values of the 768-dim gate weight elided) ...
           4.8204e-03,  3.9503e-02, -4.1356e-03]], requires_grad=True)
model.Qformer.bert.encoder.layer[10].experts.gate.weight
layer 11
0:
model.Qformer.bert.encoder.layer[11].output.dense.weight.grad
model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad
nan:
model.Qformer.bert.encoder.layer[11].attention.output.dense.weight.grad
model.Qformer.bert.encoder.layer[11].attention.self.query.weight.grad
model.Qformer.bert.encoder.layer[11].experts.intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.output_query.dense.weight.grad
None:
model.Qformer.bert.encoder.layer[11].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].output_query.dense.weight.grad
layer 8
0:
model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].output_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].output_query.dense.weight.grad
nan:
model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].output_query.dense.weight.grad
(Qformer)model.Qformer.bert.encoder.layer[8].intermediate_query.dense.weight.grad
None:
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad == None
model.Qformer.bert.encoder.layer[8].experts.gate.weight.requires_grad == True
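The 0 / nan / None status checks above can be collected in a single pass instead of probing each weight by hand. A minimal sketch, assuming `model` exposes the Q-Former layout used in these notes (the helper name is hypothetical):

import torch

def bucket_grads(model, prefix="Qformer.bert.encoder.layer"):
    # Bucket Q-Former parameters by gradient status after loss.backward().
    buckets = {"zero": [], "nan": [], "none": [], "ok": []}
    for name, p in model.named_parameters():
        if not name.startswith(prefix):
            continue
        if p.grad is None:
            buckets["none"].append(name)
        elif torch.isnan(p.grad).any():
            buckets["nan"].append(name)
        elif torch.count_nonzero(p.grad).item() == 0:
            buckets["zero"].append(name)
        else:
            buckets["ok"].append(name)
    return buckets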
model.Qformer.bert.encoder.layer[6].experts.gate.weight
Qformer.bert.encoder.layer.6.experts.gate.weight
tensor([[-0.0089, -0.0123, -0.0168, ..., -0.0072, 0.0295, -0.0167],
[ 0.0305, 0.0277, -0.0215, ..., 0.0149, 0.0016, -0.0415],
[ 0.0199, 0.0151, 0.0237, ..., 0.0007, 0.0023, 0.0167]],
requires_grad=True)
tensor([[-0.0089, -0.0123, -0.0168, ..., -0.0072, 0.0295, -0.0167],
[ 0.0305, 0.0277, -0.0215, ..., 0.0149, 0.0016, -0.0415],
[ 0.0199, 0.0151, 0.0237, ..., 0.0007, 0.0023, 0.0167]],
requires_grad=True)
tensor([[ 4.5972e-02, -1.5231e-02, -6.9533e-03,
          ... (full 768-dim gate weight values elided) ...
          -6.8762e-03,  1.4032e-02, -4.3389e-03]], requires_grad=True)

109
test1.txt Normal file
View File

@@ -0,0 +1,109 @@
from torchviz import make_dot
# Render the backward graph of the Q-Former query output, to see which
# parameters (experts / gate) are actually reachable from the loss.
dot = make_dot(query_output.last_hidden_state, params=dict(self.Qformer.bert.named_parameters()))
log_dir = '/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/'
dot.render(filename="Pre_PromptMoE_RawProb_backward_graph", directory=log_dir, format="pdf")
# Pre-Prompt-MoE
model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[2].dense1.weight.grad
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight
model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].intermediate.dense.weight.grad
model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad
model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight
model.Qformer.bert.encoder.layer[10].experts.experts[2].dense1.weight
model.Qformer.bert.encoder.layer[10].experts.experts[1].dense1.weight
model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight
model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight == model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight
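# Note: `==` on two weight tensors returns an element-wise boolean tensor, not a single
# yes/no answer. A sketch of an unambiguous check, assuming the attribute layout above:
import torch
layer10 = model.Qformer.bert.encoder.layer[10]
same = torch.equal(layer10.intermediate_query.dense.weight,
                   layer10.experts.experts[0].dense1.weight)
print(same)  # True only if expert 0 is still an exact copy of the shared FFN weight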
# Pre-MoE gate-sentence
# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad does not get updated
# Pre-MoE gate-token
# updates normally
# Post-MoE gate-sentence
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad updates normally
# model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad is all 0/-0
# model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad is all 0/-0
# Route-MoE
# Pre-MoE: the computed beam_scores are problematic
# Post-Route: updates the parameters of multiple experts, and also updates the gate parameters
# Layer 6 updated the parameters of two experts (layer 6, layer 8)
# model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad is 0, all zeros
# model.Qformer.bert.encoder.layer[11].output.dense.weight.grad
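# One way to read the gate-gradient behaviour above: with hard pre-routing, only the
# argmax index of the gate output is used in the forward pass, so no gradient ever
# reaches the gate (grad stays None / all zeros); with post-routing the expert outputs
# are weighted by the gate probabilities, so the loss back-propagates into the gate.
# Illustrative sketch only (placeholder `gate` / `experts` modules, not this repo's
# implementation):
import torch
import torch.nn.functional as F

def pre_moe_hard(x, gate, experts):
    # Hard routing: one expert per sequence, chosen from the gate logits.
    idx = gate(x.mean(dim=1)).argmax(dim=-1)                 # [batch]
    return torch.stack([experts[int(i)](x[b]) for b, i in enumerate(idx)])

def post_moe_soft(x, gate, experts):
    # Post routing: every expert output weighted by its gate probability.
    prob = F.softmax(gate(x.mean(dim=1)), dim=-1)            # [batch, num_experts]
    outs = torch.stack([e(x) for e in experts], dim=1)       # [batch, E, seq, dim]
    return (prob[:, :, None, None] * outs).sum(dim=1)        # [batch, seq, dim]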
model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[6].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[7].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[7].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[7].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[9].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[9].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[9].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[10].experts.experts[1].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.gate.weight.grad
model.Qformer.bert.encoder.layer[11].experts.experts[0].intermediate_query.dense.weight.grad
model.Qformer.bert.encoder.layer[11].experts.experts[1].intermediate_query.dense.weight.grad
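# The per-layer checks above can also be printed in one loop; a small sketch under the
# same attribute-layout assumption:
for i in range(6, 12):
    layer = model.Qformer.bert.encoder.layer[i]
    print(i, "gate grad:", layer.experts.gate.weight.grad)
    for j, expert in enumerate(layer.experts.experts):
        print(i, j, "expert grad:", expert.intermediate_query.dense.weight.grad)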
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.10.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[-0.0328, 0.0414, 0.0010, ..., -0.0068, 0.0244, 0.0587],
[ 0.0120, 0.0458, 0.0171, ..., -0.0439, -0.0107, -0.0397],
[ 0.0239, 0.0191, -0.0145, ..., 0.0008, -0.0067, 0.0090],
...,
[ 0.0174, -0.0465, -0.0106, ..., -0.0095, 0.0153, -0.0195],
[-0.0151, -0.0082, -0.0320, ..., -0.0016, -0.0232, -0.0147],
[ 0.0142, -0.0286, 0.0161, ..., -0.0160, -0.0306, -0.0272]],
device='cuda:0', requires_grad=True)]
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.8.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[ 0.0024, 0.0218, -0.0186, ..., -0.0178, -0.0067, 0.0820],
[-0.0759, -0.0002, -0.0548, ..., 0.0292, 0.0531, 0.0779],
[-0.0220, -0.0037, -0.0520, ..., -0.0426, -0.0261, -0.0357],
...,
[-0.0448, 0.0471, 0.0133, ..., -0.0062, -0.0217, -0.0203],
[ 0.0532, 0.0197, 0.0320, ..., -0.0010, -0.0838, 0.0682],
[ 0.0284, 0.0038, -0.0007, ..., -0.0305, 0.0296, 0.0056]],
device='cuda:0', requires_grad=True)]
(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.6.experts.experts.0.dense1.weight']
[Parameter containing:
tensor([[ 6.5176e-02, -4.6473e-02, -2.7396e-02, ..., 2.1774e-03,
6.1457e-02, 1.9180e-03],
[ 7.3707e-03, 6.1392e-02, -2.7108e-02, ..., 4.0778e-02,
-1.9791e-02, -1.1612e-02],
[ 2.1193e-02, -3.8323e-02, -6.0238e-02, ..., -1.4539e-02,
9.2965e-02, 3.9153e-02],
...,
[ 5.3203e-03, -1.7276e-02, -3.2191e-02, ..., -1.6435e-02,
-1.8553e-02, -2.8158e-02],
[-6.9853e-02, 9.2719e-03, -1.8895e-03, ..., -2.6425e-02,
1.4880e-03, 3.4505e-02],
[-1.2168e-03, 3.7038e-02, 4.8047e-02, ..., -3.4523e-03,
-1.3030e-05, -1.4778e-02]], device='cuda:0', requires_grad=True)]
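# The same lookups can be done through a dict view of named_parameters(); a sketch
# assuming the parameter names printed above:
(Pdb) params = dict(self.model.named_parameters())
(Pdb) params['Qformer.bert.encoder.layer.6.experts.experts.0.dense1.weight'].device
(Pdb) params['Qformer.bert.encoder.layer.6.experts.experts.0.dense1.weight'].requires_grad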