1
0
mirror of https://github.com/Vision-CAIR/MiniGPT-4.git synced 2025-04-07 19:40:45 +00:00

0329 add cls token

This commit is contained in:
hanziwang 2024-03-29 10:09:06 +00:00
parent 5bec4d0608
commit 2057032a63
11 changed files with 3200 additions and 29 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -517,15 +517,6 @@ class BertLayer(nn.Module):
outputs + cross_attention_outputs[1:-1]
) # add cross attentions if we output attention weights
# add moe query ffn
# query_attention_output size: [bz, query_length+seq_len, 768]
# attention_mask size: [bz, 1, 1, query_length+seq_len]
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
import pdb; pdb.set_trace() # 0107test
if attention_output.shape[1] > query_length: # have text input in Qformer
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
@ -533,6 +524,19 @@ class BertLayer(nn.Module):
self.seq_len_dim,
attention_output[:, query_length:, :],
)
cls_hidden = layer_output_text[0][:, 0, :] # [bz, hidden_size]
# add moe query ffn
# query_attention_output size: [bz, query_length+seq_len, 768]
# attention_mask size: [bz, 1, 1, query_length+seq_len]
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden)
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
# import pdb; pdb.set_trace() # 0107test
if attention_output.shape[1] > query_length: # have text input in Qformer
if self.layer_judge == 'first' and self.num_beams>1:
# if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
# adjust the dimension of layer_output_text to bz*num_beams
@ -622,13 +626,13 @@ class BertLayer(nn.Module):
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden):
if not self.use_experts:
layer_output = self.experts(attention_output)
return layer_output, None, None, None, 0.0
layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
attention_output, expert_attention_mask, beam_scores, expert_route
attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden
)
return layer_output, beam_scores, expert_route, beam_idx, importance_loss

View File

@ -523,12 +523,22 @@ class BertLayer(nn.Module):
outputs + cross_attention_outputs[1:-1]
) # add cross attentions if we output attention weights
if attention_output.shape[1] > query_length: # have text input in Qformer
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
cls_hidden = layer_output_text[0][:, 0, :] # [bz, hidden_size]
# add moe query ffn
# query_attention_output size: [bz, query_length+seq_len, 768]
# attention_mask size: [bz, 1, 1, query_length+seq_len]
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden)
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
# import pdb; pdb.set_trace() # 0107test
@ -628,14 +638,14 @@ class BertLayer(nn.Module):
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden):
if not self.use_experts:
hidden_states = self.experts(attention_output)
layer_output = self.expert_ln(hidden_states + attention_output)
return layer_output, None, None, None, 0.0
hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
attention_output, expert_attention_mask, beam_scores, expert_route
attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden
)
if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1:
attention_output = self.adjust_hidden_states_by_num_beams(attention_output)

View File

@ -27,6 +27,8 @@ from minigpt4.models.QformerRouteMoE import BertMoERouteLMHeadModel
from minigpt4.models.QformerRouteMoELN import BertMoERouteLMHeadModelLNIn
from minigpt4.models.QformerRouteMoELNUni import BertMoERouteLMHeadModelLNInUniversal
from minigpt4.models.QformerRouteMoEUni import BertMoERouteLMHeadModelUniversal
from minigpt4.models.QformerRouteMoECLS import BertMoECLSRouteLMHeadModel
from minigpt4.models.QformerRouteMoECLSLN import BertMoECLSRouteLMHeadModelLNIn
from minigpt4.models.eva_vit import create_eva_vit_g
from transformers import BertTokenizer
from peft import (
@ -99,6 +101,38 @@ class Blip2Base(BaseModel):
return RouteMoEQformer, query_tokens
@classmethod
def init_RouteCLSMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"):
moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
moe_encoder_config.encoder_width = vision_width
# insert cross-attention layer every other block
moe_encoder_config.add_cross_attention = True
moe_encoder_config.cross_attention_freq = cross_attention_freq
moe_encoder_config.query_length = num_query_token
moe_encoder_config.moebert_expert_num = moebert_expert_num
moe_encoder_config.moebert_num_beams = moebert_num_beams
moe_encoder_config.route_method = route_method
moe_encoder_config.moe_weight_type = moe_weight_type
if ln_position == "out":
RouteMoEQformer = BertMoECLSRouteLMHeadModel.from_pretrained(
"/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
)
elif ln_position == "in":
# need to adjust
RouteMoEQformer = BertMoECLSRouteLMHeadModelLNIn.from_pretrained(
"/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config
)
query_tokens = nn.Parameter(
torch.zeros(1, num_query_token, moe_encoder_config.hidden_size)
)
query_tokens.data.normal_(mean=0.0, std=moe_encoder_config.initializer_range)
return RouteMoEQformer, query_tokens
@classmethod
def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"):
moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")

View File

@ -85,7 +85,7 @@ class Blip2VicunaInstruct(Blip2Base):
)
print('Initing & Loading Qformer')
if general_version in ['naive_moe', 'route_moe', 'uni_route_moe']:
if general_version in ['naive_moe', 'route_moe', 'uni_route_moe', 'cls_route_moe']:
if general_version == 'naive_moe':
self.Qformer, self.query_tokens = self.init_QformerMoE(
num_query_token=num_query_token,
@ -121,6 +121,17 @@ class Blip2VicunaInstruct(Blip2Base):
cross_attention_freq=2,
ln_position=ln_position,
)
elif general_version == 'cls_route_moe':
self.Qformer, self.query_tokens = self.init_RouteCLSMoEQformer(
num_query_token=num_query_token,
vision_width=self.visual_encoder.num_features,
moebert_expert_num=moebert_expert_num,
moebert_num_beams=moebert_num_beams,
route_method=moebert_route_method,
moe_weight_type=moe_weight_type,
cross_attention_freq=2,
ln_position=ln_position,
)
elif general_version == 'base':
self.Qformer, self.query_tokens = self.init_Qformer(

View File

@ -0,0 +1,265 @@
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
class RouteMoELayer(nn.Module):
def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
# remove hash list
nn.Module.__init__(self)
self.num_experts = num_experts
self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
self.num_beams = num_beams
self.hidden_size = hidden_size
self.layer_judge = layer_judge
self.weight_type = weight_type
self.route_method = route_method
if self.route_method == "pre-route":
self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
elif self.route_method in ["post-route", "post-route-dp"]:
gate = nn.Linear(hidden_size, 1, bias=False).float()
self.gate = gate
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
# _importance_auxiliary_loss
axis = tuple(range(prob_gate.ndim - 1)) # All except last.
importance_per_expert = torch.sum(prob_gate, dim=axis)
std_importance_per_expert = torch.std(importance_per_expert)
mean_importance_per_expert = torch.mean(importance_per_expert)
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
def forward_gate(self, x):
"""
x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts])
"""
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768])
# x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
x_average = torch.mean(x_masked, dim=1) # torch.Size([bz*num_beams, 768])
logits_gate = self.gate(x_average) # torch.Size([bz*num_beams, num_experts])
prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
return prob_gate
def dp_search(self, current_scores_log, beam_scores, expert_route, batch_size):
if self.layer_judge=='first' and self.route_method in ['post-route-dp']:
# current_scores_log torch.Size([bz, num_experts])
assert beam_scores==None and expert_route==None
current_scores = torch.exp(current_scores_log)
topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk])
beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
beam_idx = torch.tensor(range(self.num_beams * batch_size))
else:
batch_size = int(batch_size // self.num_beams)
next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率
next_scores_exp = torch.exp(next_scores_raw)
next_scores_raw, next_experts_raw = torch.topk(next_scores_exp, 1, dim=1, largest=True, sorted=True)
next_scores = next_scores_raw.view(batch_size, self.num_beams)
next_experts = next_experts_raw.view(batch_size, self.num_beams)
# next_scores, next_experts = torch.topk(current_scores_log, 1, dim=1, largest=True, sorted=True) # equal 等价
# next_scores torch.Size([bz * num_beams, 1])
# next_tokens torch.Size([bz * num_beams, 1])
next_batch_beam = list()
for batch_idx in range(batch_size):
next_sent_beam = list()
expert_id = next_experts[batch_idx]
expert_score = next_scores[batch_idx]
values, index = torch.topk(expert_score, self.num_beams, dim=0, largest=True, sorted=True)
for i in range(self.num_beams):
beam_id = index[i].item()
ex_id = expert_id[beam_id].item()
effective_beam_id = batch_idx*self.num_beams + beam_id
next_sent_beam.append((values[i], ex_id, effective_beam_id))
next_batch_beam.extend(next_sent_beam)
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
pre_route = expert_route[beam_idx,:]
expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
return beam_scores, expert_route, beam_idx
def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
# current_scores_log torch.Size([bz, num_experts])
assert beam_scores==None and expert_route==None
current_scores = torch.exp(current_scores_log)
topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk])
beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
beam_idx = torch.tensor(range(self.num_beams * batch_size))
else:
batch_size = int(batch_size // self.num_beams)
next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率
next_scores_exp = torch.exp(next_scores_raw)
next_scores_raw1 = next_scores_exp.view(
batch_size, self.num_beams * self.num_experts
) # torch.Size([bz, num_beams*num_experts])
next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
# next_scores torch.Size([bz, num_beams])
# next_tokens torch.Size([bz, num_beams])
next_batch_beam = list()
for batch_idx in range(batch_size):
next_sent_beam = list()
for rank, (expert_id, expert_score) in enumerate(
zip(next_experts[batch_idx], next_scores[batch_idx])
):
expert_id = expert_id.item()
beam_id = expert_id // self.num_experts
ex_id = expert_id % self.num_experts
effective_beam_id = batch_idx*self.num_beams + beam_id
next_sent_beam.append((expert_score, ex_id, effective_beam_id))
next_batch_beam.extend(next_sent_beam)
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
pre_route = expert_route[beam_idx,:]
expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
return beam_scores, expert_route, beam_idx
def forward_expert_ffn(self, x, expert_select, current_scores):
"""
x_repeat : [bz*num_beams, 32,768]
expert_select : [bz*num_beams]
current_scores : [bz*num_beams, num_experts] / [bz, num_experts]
"""
# add_1228 l2_normalization
# normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk])
# tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)
# import pdb;pdb.set_trace()
outputs = list()
for i in range(self.num_experts):
output_x = self.experts[i].forward(x)
outputs.append(output_x.unsqueeze(1))
candidate_output = torch.cat(outputs, dim=1)
expert_select_matrix = F.one_hot(expert_select, self.num_experts)
if self.weight_type == 'ffn_prob':
tmp_prob = current_scores * expert_select_matrix
candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
output = torch.sum(candidate_output, dim=1)
# import pdb;pdb.set_trace()
return output # torch.Size([bz*num_beams, 32, 768])
def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams]
importance_loss = self._importance_auxiliary_loss(current_scores)
if use_log:
current_scores_log = torch.log(current_scores) # 取log之后可以直接相加
else:
current_scores_log = current_scores
# import pdb;pdb.set_trace()
batch_size, num_tokens = x.shape[0], x.shape[1]
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
current_expert_select = expert_route[:,-1]
if self.layer_judge=='first': # expand first dim to batch_size * num_beams
replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
input_x = x[beam_idx]
candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768]
# import pdb;pdb.set_trace()
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
output_x = self.experts[expert_idx].forward(input_x)
return output_x
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx)
# output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768])
output_x_aver = torch.mean(output_x, dim=1)
# gate_score = self.gates[expert_idx](output_x_aver)
gate_score = self.gate(output_x_aver)
logits_gate_lst.append(gate_score)
outputs.append(output_x.unsqueeze(0))
candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
if use_log:
current_scores_log = torch.log(current_scores) # 取log之后可以直接相加
else:
current_scores_log = current_scores
# importance loss
importance_loss = self._importance_auxiliary_loss(current_scores)
batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
if self.route_method == 'post-route':
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
elif self.route_method == 'post-route-dp':
beam_scores, expert_route, beam_idx = self.dp_search(current_scores_log, beam_scores, expert_route, batch_size)
# beam_scores torch.Size([bz*num_beam])
# expert_route torch.Size([bz*num_beam, layer_n])
current_select_expert = expert_route[:,-1]
# current_select_expert torch.Size([bz*num_beam, 1])
if self.layer_judge == 'first':
replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768])
expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
if self.weight_type == 'ffn_prob':
tmp_prob = current_scores[beam_idx] * expert_select_matrix
output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
final_output = torch.sum(output, dim=1)
return final_output, beam_scores, expert_route, beam_idx, importance_loss
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
"""
if first_layer: x [bz, 32, 768]
else: x [bz*num_beams, 32, 768]
"""
if self.route_method == 'pre-route':
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
elif self.route_method in ['post-route', 'post-route-dp']:
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
else:
assert("route method should in pre-route, post-route, post-route-dp")
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

View File

@ -22,6 +22,9 @@ class RouteMoELayer(nn.Module):
gate = nn.Linear(hidden_size, 1, bias=False).float()
self.gate = gate
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
elif self.route_method in ['cls-route', 'cls-query-route', 'cls-cross-route']:
self.gate = nn.Linear(hidden_size, 1, bias=False).float()
def _importance_auxiliary_loss(self, prob_gate):
# From VMOE
@ -33,7 +36,6 @@ class RouteMoELayer(nn.Module):
# Compute coefficient of variation (i.e. std/mean) squared.
return (std_importance_per_expert / mean_importance_per_expert)**2
def forward_gate(self, x):
"""
x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
@ -91,7 +93,7 @@ class RouteMoELayer(nn.Module):
return beam_scores, expert_route, beam_idx
def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route', 'cls-route', 'cls-query-route', 'cls-cross-route']:
# current_scores_log torch.Size([bz, num_experts])
assert beam_scores==None and expert_route==None
current_scores = torch.exp(current_scores_log)
@ -247,7 +249,85 @@ class RouteMoELayer(nn.Module):
return final_output, beam_scores, expert_route, beam_idx, importance_loss
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
def calculate_cls_gate_score(self, cls_hidden, output_x):
if self.route_method == 'cls-route':
# cls_hidden = [bz, 768]
gate_score = self.gate(cls_hidden) # bz, 1
elif self.route_method == 'cls-query-route': # add cls_hiddin on query_token mean pool hidden
mean_output = torch.mean(output_x, dim=1) # bz, 768
gate_score = self.gate(mean_output+cls_hidden) # bz, 1
elif self.route_method == 'cls-cross-route':
# cls_hidden as Q, output_x as K, V calculate scaled dot-product attention between Q and K and V
# cls_hidden: bz, 768
# output_x: bz, 32, 768
Q = cls_hidden.unsqueeze(1) # bz, 1, 768
K = output_x # bz, 32, 768
V = output_x # bz, 32, 768
# scaled dot-product attention
QK = torch.matmul(Q, K.transpose(-1, -2)) / (K.size(-1) ** 0.5) # bz, 1, 32
QK = F.softmax(QK, dim=-1) # bz, 1, 32
gate_score = torch.matmul(QK, V) # bz, 1, 768
gate_score = gate_score.squeeze(1) # bz, 768
gate_score = self.gate(gate_score) # bz, 1
return gate_score
def forward_cls_route(self, x, beam_scores, expert_route, cls_hidden):
attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
def forward_expert(input_x, expert_idx):
output_x = self.experts[expert_idx].forward(input_x)
return output_x
outputs = list()
logits_gate_lst = list()
for expert_idx in range(self.num_experts):
output_x = forward_expert(x_masked, expert_idx) # bz, 32, 768
gate_score = self.calculate_cls_gate_score(cls_hidden, output_x) # bz, 1
logits_gate_lst.append(gate_score)
outputs.append(output_x.unsqueeze(0))
candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert])
current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
current_scores_log = torch.log(current_scores) # 取log之后可以直接相加
# importance loss
importance_loss = self._importance_auxiliary_loss(current_scores)
batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
# beam_scores torch.Size([bz*num_beam])
# expert_route torch.Size([bz*num_beam, layer_n])
current_select_expert = expert_route[:,-1]
# current_select_expert torch.Size([bz*num_beam, 1])
if self.layer_judge == 'first':
replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768]
current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768])
expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
if self.weight_type == 'ffn_prob':
tmp_prob = current_scores[beam_idx] * expert_select_matrix
output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
else:
output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
final_output = torch.sum(output, dim=1)
return final_output, beam_scores, expert_route, beam_idx, importance_loss
def forward(self, x, attention_mask, beam_scores, expert_route, cls_hidden=None):
"""
if first_layer: x [bz, 32, 768]
else: x [bz*num_beams, 32, 768]
@ -257,9 +337,8 @@ class RouteMoELayer(nn.Module):
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
elif self.route_method in ['post-route', 'post-route-dp']:
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
elif self.route_method in ['cls-route', 'cls-query-route', 'cls-cross-route']:
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_cls_route(x, beam_scores, expert_route, cls_hidden)
else:
assert("route method should in pre-route, post-route, post-route-dp")
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

View File

@ -142,7 +142,32 @@ class UniRouteMoELayer(nn.Module):
# import pdb;pdb.set_trace()
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
def forward_post_route_uni(self, x, beam_scores, expert_route, use_log=True):
def calculate_cls_gate_score(self, cls_hidden, output_x):
if self.route_method == 'uni-cls-route':
# cls_hidden = [bz, 768]
gate_score = self.gate(cls_hidden) # bz, 1
elif self.route_method == 'uni-cls-query-route': # add cls_hiddin on query_token mean pool hidden
mean_output = torch.mean(output_x, dim=1) # bz, 768
gate_score = self.gate(mean_output+cls_hidden) # bz, 1
elif self.route_method == 'uni-cls-cross-route':
# cls_hidden as Q, output_x as K, V calculate scaled dot-product attention between Q and K and V
# cls_hidden: bz, 768
# output_x: bz, 32, 768
Q = cls_hidden.unsqueeze(1) # bz, 1, 768
K = output_x # bz, 32, 768
V = output_x # bz, 32, 768
# scaled dot-product attention
QK = torch.matmul(Q, K.transpose(-1, -2)) / (K.size(-1) ** 0.5) # bz, 1, 32
QK = F.softmax(QK, dim=-1) # bz, 1, 32
gate_score = torch.matmul(QK, V) # bz, 1, 768
gate_score = gate_score.squeeze(1) # bz, 768
gate_score = self.gate(gate_score) # bz, 1
return gate_score
def forward_route_uni(self, x, beam_scores, expert_route, use_log=True, cls_hidden=None):
if beam_scores == None:
batch_size = x.shape[0]
@ -165,8 +190,14 @@ class UniRouteMoELayer(nn.Module):
logits_gate_lst = list()
for expert_idx in range(self.num_route_experts): # num_expert-1
output_x = forward_expert(x_masked, expert_idx)
output_x_aver = torch.mean(output_x, dim=1)
gate_score = self.gate(output_x_aver)
if self.route_method == 'post-route-uni':
output_x_aver = torch.mean(output_x, dim=1)
gate_score = self.gate(output_x_aver)
elif self.route_method in ['uni-cls-route', 'uni-cls-query-route', 'uni-cls-cross-route'] and cls_hidden is not None:
gate_score = self.calculate_cls_gate_score(cls_hidden, output_x)
logits_gate_lst.append(gate_score)
outputs.append(output_x.unsqueeze(0))
@ -198,8 +229,6 @@ class UniRouteMoELayer(nn.Module):
output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
experts_output = torch.sum(output, dim=1) # [bz*num_beams-1, 32, 768]
# import pdb; pdb.set_trace()
####################
### universal expert
####################
@ -219,15 +248,15 @@ class UniRouteMoELayer(nn.Module):
return final_output, beam_scores, expert_route, beam_idx, importance_loss
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
def forward(self, x, attention_mask, beam_scores, expert_route, cls_hidden):
"""
if first_layer: x [bz, 32, 768]
else: x [bz*num_beams, 32, 768]
"""
if self.route_method == 'pre-route-uni':
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
elif self.route_method in ['post-route-uni']:
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route_uni(x, beam_scores, expert_route, use_log=True)
elif self.route_method in ['post-route-uni', 'uni-cls-route', 'uni-cls-query-route', 'uni-cls-cross-route']:
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_route_uni(x, beam_scores, expert_route, use_log=True, cls_hidden=cls_hidden)
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss