mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-07 19:40:45 +00:00
0329 add cls token
This commit is contained in:
parent
5bec4d0608
commit
2057032a63
1374
minigpt4/models/QformerRouteMoECLS.py
Normal file
1374
minigpt4/models/QformerRouteMoECLS.py
Normal file
File diff suppressed because it is too large
Load Diff
1365
minigpt4/models/QformerRouteMoECLSLN.py
Normal file
1365
minigpt4/models/QformerRouteMoECLSLN.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -517,15 +517,6 @@ class BertLayer(nn.Module):
|
||||
outputs + cross_attention_outputs[1:-1]
|
||||
) # add cross attentions if we output attention weights
|
||||
|
||||
# add moe query ffn
|
||||
# query_attention_output size: [bz, query_length+seq_len, 768]
|
||||
# attention_mask size: [bz, 1, 1, query_length+seq_len]
|
||||
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
|
||||
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
|
||||
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
|
||||
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
|
||||
import pdb; pdb.set_trace() # 0107test
|
||||
|
||||
if attention_output.shape[1] > query_length: # have text input in Qformer
|
||||
layer_output_text = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
@ -533,6 +524,19 @@ class BertLayer(nn.Module):
|
||||
self.seq_len_dim,
|
||||
attention_output[:, query_length:, :],
|
||||
)
|
||||
cls_hidden = layer_output_text[0][:, 0, :] # [bz, hidden_size]
|
||||
|
||||
# add moe query ffn
|
||||
# query_attention_output size: [bz, query_length+seq_len, 768]
|
||||
# attention_mask size: [bz, 1, 1, query_length+seq_len]
|
||||
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
|
||||
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
|
||||
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden)
|
||||
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
|
||||
# import pdb; pdb.set_trace() # 0107test
|
||||
|
||||
if attention_output.shape[1] > query_length: # have text input in Qformer
|
||||
|
||||
if self.layer_judge == 'first' and self.num_beams>1:
|
||||
# if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1:
|
||||
# adjust the dimension of layer_output_text to bz*num_beams
|
||||
@ -622,13 +626,13 @@ class BertLayer(nn.Module):
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
|
||||
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden):
|
||||
if not self.use_experts:
|
||||
layer_output = self.experts(attention_output)
|
||||
return layer_output, None, None, None, 0.0
|
||||
|
||||
layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
|
||||
attention_output, expert_attention_mask, beam_scores, expert_route
|
||||
attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden
|
||||
)
|
||||
return layer_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||
|
||||
|
@ -523,12 +523,22 @@ class BertLayer(nn.Module):
|
||||
outputs + cross_attention_outputs[1:-1]
|
||||
) # add cross attentions if we output attention weights
|
||||
|
||||
|
||||
if attention_output.shape[1] > query_length: # have text input in Qformer
|
||||
layer_output_text = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output[:, query_length:, :],
|
||||
)
|
||||
cls_hidden = layer_output_text[0][:, 0, :] # [bz, hidden_size]
|
||||
|
||||
# add moe query ffn
|
||||
# query_attention_output size: [bz, query_length+seq_len, 768]
|
||||
# attention_mask size: [bz, 1, 1, query_length+seq_len]
|
||||
moe_ffn_attention_input = query_attention_output[:, :query_length, :]
|
||||
moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]
|
||||
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route)
|
||||
layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden)
|
||||
# layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss)
|
||||
# import pdb; pdb.set_trace() # 0107test
|
||||
|
||||
@ -628,14 +638,14 @@ class BertLayer(nn.Module):
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route):
|
||||
def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden):
|
||||
if not self.use_experts:
|
||||
hidden_states = self.experts(attention_output)
|
||||
layer_output = self.expert_ln(hidden_states + attention_output)
|
||||
return layer_output, None, None, None, 0.0
|
||||
|
||||
hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts(
|
||||
attention_output, expert_attention_mask, beam_scores, expert_route
|
||||
attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden
|
||||
)
|
||||
if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1:
|
||||
attention_output = self.adjust_hidden_states_by_num_beams(attention_output)
|
||||
|
@ -27,6 +27,8 @@ from minigpt4.models.QformerRouteMoE import BertMoERouteLMHeadModel
|
||||
from minigpt4.models.QformerRouteMoELN import BertMoERouteLMHeadModelLNIn
|
||||
from minigpt4.models.QformerRouteMoELNUni import BertMoERouteLMHeadModelLNInUniversal
|
||||
from minigpt4.models.QformerRouteMoEUni import BertMoERouteLMHeadModelUniversal
|
||||
from minigpt4.models.QformerRouteMoECLS import BertMoECLSRouteLMHeadModel
|
||||
from minigpt4.models.QformerRouteMoECLSLN import BertMoECLSRouteLMHeadModelLNIn
|
||||
from minigpt4.models.eva_vit import create_eva_vit_g
|
||||
from transformers import BertTokenizer
|
||||
from peft import (
|
||||
@ -99,6 +101,38 @@ class Blip2Base(BaseModel):
|
||||
return RouteMoEQformer, query_tokens
|
||||
|
||||
|
||||
@classmethod
def init_RouteCLSMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"):
    """Build the CLS-routed MoE Q-Former and its learnable query tokens.

    Args:
        num_query_token: number of query tokens; also stored on the config as
            ``query_length``.
        vision_width: stored on the config as ``encoder_width``.
        moebert_expert_num: stored on the config as ``moebert_expert_num``.
        moebert_num_beams: stored on the config as ``moebert_num_beams``.
        route_method: stored on the config as ``route_method``.
        moe_weight_type: stored on the config as ``moe_weight_type``.
        cross_attention_freq: stored on the config as ``cross_attention_freq``.
        ln_position: ``"out"`` selects ``BertMoECLSRouteLMHeadModel``,
            ``"in"`` selects ``BertMoECLSRouteLMHeadModelLNIn``.

    Returns:
        ``(RouteMoEQformer, query_tokens)`` where ``query_tokens`` is an
        ``nn.Parameter`` of shape ``[1, num_query_token, hidden_size]``.

    Raises:
        ValueError: if ``ln_position`` is neither ``"out"`` nor ``"in"``.
    """
    # Fail fast: previously an unknown ln_position left the model variable
    # unbound and only surfaced as an UnboundLocalError at the return below,
    # after the config had already been loaded.
    if ln_position not in ("out", "in"):
        raise ValueError(
            f"ln_position must be 'out' or 'in', got {ln_position!r}"
        )

    bert_path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased"
    moe_encoder_config = BertConfig.from_pretrained(bert_path)

    moe_encoder_config.encoder_width = vision_width
    # insert cross-attention layer every other block
    moe_encoder_config.add_cross_attention = True
    moe_encoder_config.cross_attention_freq = cross_attention_freq
    moe_encoder_config.query_length = num_query_token

    moe_encoder_config.moebert_expert_num = moebert_expert_num
    moe_encoder_config.moebert_num_beams = moebert_num_beams
    moe_encoder_config.route_method = route_method
    moe_encoder_config.moe_weight_type = moe_weight_type

    if ln_position == "out":
        RouteMoEQformer = BertMoECLSRouteLMHeadModel.from_pretrained(
            bert_path, config=moe_encoder_config
        )
    else:
        # ln_position == "in"
        # need to adjust (note carried over from the original author)
        RouteMoEQformer = BertMoECLSRouteLMHeadModelLNIn.from_pretrained(
            bert_path, config=moe_encoder_config
        )

    query_tokens = nn.Parameter(
        torch.zeros(1, num_query_token, moe_encoder_config.hidden_size)
    )
    query_tokens.data.normal_(mean=0.0, std=moe_encoder_config.initializer_range)

    return RouteMoEQformer, query_tokens
|
||||
|
||||
|
||||
@classmethod
|
||||
def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"):
|
||||
moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
|
||||
|
@ -85,7 +85,7 @@ class Blip2VicunaInstruct(Blip2Base):
|
||||
)
|
||||
|
||||
print('Initing & Loading Qformer')
|
||||
if general_version in ['naive_moe', 'route_moe', 'uni_route_moe']:
|
||||
if general_version in ['naive_moe', 'route_moe', 'uni_route_moe', 'cls_route_moe']:
|
||||
if general_version == 'naive_moe':
|
||||
self.Qformer, self.query_tokens = self.init_QformerMoE(
|
||||
num_query_token=num_query_token,
|
||||
@ -121,6 +121,17 @@ class Blip2VicunaInstruct(Blip2Base):
|
||||
cross_attention_freq=2,
|
||||
ln_position=ln_position,
|
||||
)
|
||||
elif general_version == 'cls_route_moe':
|
||||
self.Qformer, self.query_tokens = self.init_RouteCLSMoEQformer(
|
||||
num_query_token=num_query_token,
|
||||
vision_width=self.visual_encoder.num_features,
|
||||
moebert_expert_num=moebert_expert_num,
|
||||
moebert_num_beams=moebert_num_beams,
|
||||
route_method=moebert_route_method,
|
||||
moe_weight_type=moe_weight_type,
|
||||
cross_attention_freq=2,
|
||||
ln_position=ln_position,
|
||||
)
|
||||
|
||||
elif general_version == 'base':
|
||||
self.Qformer, self.query_tokens = self.init_Qformer(
|
||||
|
265
minigpt4/models/moe/backup/route_moe_layer_backup.py
Normal file
265
minigpt4/models/moe/backup/route_moe_layer_backup.py
Normal file
@ -0,0 +1,265 @@
|
||||
import copy
|
||||
import pickle
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class RouteMoELayer(nn.Module):
    """Mixture-of-experts layer whose expert choice is tracked as a beam search.

    Every call appends one column to ``expert_route`` (the expert picked at
    this layer for each beam) and updates ``beam_scores`` (the running product
    of the gate probabilities along each beam's route).

    Supported ``route_method`` values:

    * ``"pre-route"``: a ``Linear(hidden_size, num_experts)`` gate scores the
      token-averaged *input*.
    * ``"post-route"``: every expert runs, and a shared
      ``Linear(hidden_size, 1)`` gate scores each expert's token-averaged
      *output*; beams compete across all (beam, expert) pairs.
    * ``"post-route-dp"``: same scoring as ``"post-route"``, but each beam
      keeps its own best expert and beams are only re-sorted by score.

    Shape comments use ``bz`` for batch size; the original author's notes use
    32 query tokens and a hidden size of 768.
    """

    def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"):
        """Create ``num_experts`` deep copies of ``expert`` and the routing gate.

        Args:
            hidden_size: feature size of the tokens fed to the gate.
            expert: module that is deep-copied once per expert.
            num_experts: number of expert copies.
            num_beams: number of routes kept per sample.
            layer_judge: ``'first'`` marks the layer that opens the beams
                (input is ``[bz, ...]``); any other value means the input is
                already ``[bz*num_beams, ...]``.
            route_method: see the class docstring.
            weight_type: ``'ffn_prob'`` scales the chosen expert's output by
                its gate probability; anything else uses the output unscaled.

        Note:
            For an unrecognised ``route_method`` no ``self.gate`` is created;
            ``forward`` rejects such a layer with ``ValueError``.
        """
        # remove hash list
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(num_experts)])
        self.num_beams = num_beams
        self.hidden_size = hidden_size
        self.layer_judge = layer_judge
        self.weight_type = weight_type

        self.route_method = route_method
        if self.route_method == "pre-route":
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif self.route_method in ["post-route", "post-route-dp"]:
            # One shared scalar gate, applied to every expert's pooled output.
            self.gate = nn.Linear(hidden_size, 1, bias=False).float()

    def _importance_auxiliary_loss(self, prob_gate):
        """Return the squared coefficient of variation of per-expert importance.

        Importance is the sum of ``prob_gate`` over every dimension except the
        last (the expert dimension).  Per the original comment this follows
        the V-MoE importance auxiliary loss.
        """
        axis = tuple(range(prob_gate.ndim - 1))  # All except last.
        importance_per_expert = torch.sum(prob_gate, dim=axis)
        std_importance_per_expert = torch.std(importance_per_expert)
        mean_importance_per_expert = torch.mean(importance_per_expert)
        # Compute coefficient of variation (i.e. std/mean) squared.
        return (std_importance_per_expert / mean_importance_per_expert) ** 2

    def forward_gate(self, x):
        """Score experts from the token-averaged input.

        x : [bz*num_beams, num_tokens, hidden] or [bz, num_tokens, hidden]
        prob_gate : [bz*num_beams, num_experts] or [bz, num_experts]
        """
        # The mask is all ones, so this is numerically a no-op.  It is kept
        # because multiplying by a default-dtype tensor also applies PyTorch
        # type promotion to ``x`` before it reaches the ``.float()`` gate.
        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)
        x_average = torch.mean(x_masked, dim=1)  # [rows, hidden]
        logits_gate = self.gate(x_average)  # [rows, num_experts]
        prob_gate = F.softmax(logits_gate, dim=-1)
        return prob_gate

    def _start_beams(self, current_scores_log, beam_scores, expert_route, batch_size):
        """Open ``num_beams`` beams per sample from first-layer scores.

        ``current_scores_log`` is ``[bz, num_experts]``.  Returns flattened
        ``beam_scores`` ``[bz*num_beams]``, ``expert_route``
        ``[bz*num_beams, 1]`` and an identity ``beam_idx``.
        """
        assert beam_scores is None and expert_route is None
        current_scores = torch.exp(current_scores_log)
        # gate: the experts assigned to each sample, [bz, num_beams]
        topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1)
        total = self.num_beams * batch_size
        beam_scores = topk_values.view(total)
        expert_route = gate.view(total).unsqueeze(1)
        # Built on the scores' device, matching the later-layer branch, which
        # builds beam_idx on expert_route's device.
        beam_idx = torch.arange(total, device=current_scores_log.device)
        return beam_scores, expert_route, beam_idx

    def _select_beams(self, beam_scores, expert_route, candidates):
        """Turn ``(score, expert, parent_beam)`` tuples into the new beam state.

        The tuples hold plain Python numbers, so the rebuilt ``beam_scores``
        carries no autograd history.
        NOTE(review): the original rebuilt the tensor from a Python list too,
        which also appears to drop autograd history - confirm intended.
        """
        new_scores = beam_scores.new_tensor([score for score, _, _ in candidates])
        new_experts = expert_route.new_tensor([expert for _, expert, _ in candidates])
        beam_idx = expert_route.new_tensor([parent for _, _, parent in candidates])
        # Each surviving beam inherits its parent's route and appends its expert.
        expert_route = torch.cat([expert_route[beam_idx, :], new_experts.unsqueeze(1)], dim=-1)
        return new_scores, expert_route, beam_idx

    def dp_search(self, current_scores_log, beam_scores, expert_route, batch_size):
        """Extend every beam with its own best expert (``post-route-dp``).

        On the first layer this opens the beams.  Afterwards each beam takes
        its top-1 expert, and the beams of a sample are re-ordered by score.

        Returns:
            ``(beam_scores, expert_route, beam_idx)``; ``beam_idx[j]`` is the
            row of the previous beam state that new beam ``j`` continues.
        """
        if self.layer_judge == 'first' and self.route_method in ['post-route-dp']:
            # current_scores_log: [bz, num_experts]
            return self._start_beams(current_scores_log, beam_scores, expert_route, batch_size)

        batch_size = int(batch_size // self.num_beams)
        # In log space the probabilities can simply be added.
        next_scores = torch.exp(current_scores_log + torch.log(beam_scores).unsqueeze(1))

        best_scores, best_experts = torch.topk(next_scores, 1, dim=1, largest=True, sorted=True)
        best_scores = best_scores.view(batch_size, self.num_beams)
        best_experts = best_experts.view(batch_size, self.num_beams)

        # Sort the beams of each sample by score (top-k over all num_beams).
        values, order = torch.topk(best_scores, self.num_beams, dim=1, largest=True, sorted=True)
        chosen = torch.gather(best_experts, 1, order)

        candidates = []
        rows = zip(values.tolist(), order.tolist(), chosen.tolist())
        for batch_idx, (row_scores, row_beams, row_experts) in enumerate(rows):
            for score, beam_id, ex_id in zip(row_scores, row_beams, row_experts):
                candidates.append((score, ex_id, batch_idx * self.num_beams + beam_id))

        return self._select_beams(beam_scores, expert_route, candidates)

    def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
        """Extend the beams with the best (beam, expert) pairs per sample.

        On the first layer (``pre-route`` / ``post-route``) this opens the
        beams.  Afterwards all ``num_beams * num_experts`` continuations of a
        sample compete and the ``num_beams`` best survive.

        Returns:
            ``(beam_scores, expert_route, beam_idx)``; ``beam_idx[j]`` is the
            row of the previous beam state that new beam ``j`` continues.
        """
        if self.layer_judge == 'first' and self.route_method in ['pre-route', 'post-route']:
            # current_scores_log: [bz, num_experts]
            return self._start_beams(current_scores_log, beam_scores, expert_route, batch_size)

        batch_size = int(batch_size // self.num_beams)
        # In log space the probabilities can simply be added.
        next_scores = torch.exp(current_scores_log + torch.log(beam_scores).unsqueeze(1))
        flat_scores = next_scores.view(
            batch_size, self.num_beams * self.num_experts
        )  # [bz, num_beams*num_experts]

        top_scores, top_flat = torch.topk(flat_scores, self.num_beams, dim=1, largest=True, sorted=True)
        # top_scores / top_flat: [bz, num_beams]

        candidates = []
        rows = zip(top_scores.tolist(), top_flat.tolist())
        for batch_idx, (row_scores, row_flat) in enumerate(rows):
            for score, flat_id in zip(row_scores, row_flat):
                # A flat index encodes (parent beam, expert).
                beam_id, ex_id = divmod(flat_id, self.num_experts)
                candidates.append((score, ex_id, batch_idx * self.num_beams + beam_id))

        return self._select_beams(beam_scores, expert_route, candidates)

    def forward_expert_ffn(self, x, expert_select, current_scores):
        """Run every expert on ``x`` and keep the selected expert's output.

        x : [bz*num_beams, num_tokens, hidden]
        expert_select : [bz*num_beams] expert index per row
        current_scores : [bz*num_beams, num_experts], row-aligned with ``x``

        Returns ``[bz*num_beams, num_tokens, hidden]``.  With
        ``weight_type == 'ffn_prob'`` the kept output is scaled by the
        selected expert's probability.
        """
        # Call the modules rather than ``.forward`` so registered hooks run.
        candidate_output = torch.stack([expert(x) for expert in self.experts], dim=1)
        expert_select_matrix = F.one_hot(expert_select, self.num_experts)
        if self.weight_type == 'ffn_prob':
            weights = current_scores * expert_select_matrix
        else:
            weights = expert_select_matrix
        # One-hot weights zero out all but the selected expert before the sum.
        return torch.sum(candidate_output * weights.unsqueeze(-1).unsqueeze(-1), dim=1)

    def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
        """Route on the input: gate first, then run the experts.

        Returns ``(output, beam_scores, expert_route, beam_idx,
        importance_loss)`` with ``output`` of shape
        ``[bz*num_beams, num_tokens, hidden]``.
        """
        current_scores = self.forward_gate(x)  # [bz, num_experts] / [bz*num_beams, num_experts]

        importance_loss = self._importance_auxiliary_loss(current_scores)

        # In log space the probabilities can simply be added.
        current_scores_log = torch.log(current_scores) if use_log else current_scores

        batch_size, num_tokens = x.shape[0], x.shape[1]
        beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
        current_expert_select = expert_route[:, -1]

        if self.layer_judge == 'first':  # expand first dim to batch_size * num_beams
            x = (
                x.unsqueeze(1)
                .expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
                .contiguous()
                .view(-1, num_tokens, self.hidden_size)
            )  # [bz*num_beams, num_tokens, hidden]
            current_scores = (
                current_scores.unsqueeze(1)
                .expand(batch_size, self.num_beams, self.num_experts)
                .contiguous()
                .view(-1, self.num_experts)
            )  # [bz*num_beams, num_experts]

        # New beam j continues old beam beam_idx[j], so both the hidden states
        # and the gate probabilities must be gathered with beam_idx.  The
        # probabilities were previously passed un-gathered, unlike in
        # forward_post_route, which misaligned them after a beam reordering.
        input_x = x[beam_idx]
        candidate_output = self.forward_expert_ffn(
            input_x, current_expert_select, current_scores[beam_idx]
        )  # [bz*num_beams, num_tokens, hidden]
        return candidate_output, beam_scores, expert_route, beam_idx, importance_loss

    def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
        """Route on the outputs: run every expert, then gate their outputs.

        Returns ``(output, beam_scores, expert_route, beam_idx,
        importance_loss)`` with ``output`` of shape
        ``[bz*num_beams, num_tokens, hidden]``.

        Raises:
            ValueError: if ``route_method`` is not a post-route method.
        """
        # All-ones mask: numerically a no-op, kept for its type promotion.
        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)  # [rows, num_tokens, hidden]

        outputs = []
        logits_gate_lst = []
        for expert in self.experts:
            # Call the module rather than ``.forward`` so registered hooks run.
            output_x = expert(x_masked)
            output_x_aver = torch.mean(output_x, dim=1)
            logits_gate_lst.append(self.gate(output_x_aver))  # [rows, 1]
            outputs.append(output_x)

        candidate_output_raw = torch.stack(outputs)  # [num_experts, rows, num_tokens, hidden]
        logits_gate = torch.cat(logits_gate_lst, dim=1)  # [rows, num_experts]
        current_scores = F.softmax(logits_gate, dim=-1)  # [rows, num_experts]

        # In log space the probabilities can simply be added.
        current_scores_log = torch.log(current_scores) if use_log else current_scores

        # importance loss
        importance_loss = self._importance_auxiliary_loss(current_scores)

        batch_size, num_tokens = x.shape[0], x.shape[1]  # rows is bz or bz*num_beams

        if self.route_method == 'post-route':
            beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
        elif self.route_method == 'post-route-dp':
            beam_scores, expert_route, beam_idx = self.dp_search(current_scores_log, beam_scores, expert_route, batch_size)
        else:
            raise ValueError(
                f"forward_post_route needs 'post-route' or 'post-route-dp', got {self.route_method!r}"
            )

        # beam_scores: [bz*num_beams]; expert_route: [bz*num_beams, layer_n]
        current_select_expert = expert_route[:, -1]  # [bz*num_beams]

        if self.layer_judge == 'first':
            # Replicate each sample once per beam.
            candidate_output_raw = (
                candidate_output_raw.unsqueeze(2)
                .expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
                .contiguous()
                .view(self.num_experts, -1, num_tokens, self.hidden_size)
            )
            current_scores = (
                current_scores.unsqueeze(1)
                .expand(batch_size, self.num_beams, self.num_experts)
                .contiguous()
                .view(-1, self.num_experts)
            )  # [bz*num_beams, num_experts]

        # [bz*num_beams, num_experts, num_tokens, hidden], gathered per parent beam
        candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx]
        expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
        if self.weight_type == 'ffn_prob':
            weights = current_scores[beam_idx] * expert_select_matrix
        else:
            weights = expert_select_matrix
        final_output = torch.sum(candidate_output * weights.unsqueeze(-1).unsqueeze(-1), dim=1)

        return final_output, beam_scores, expert_route, beam_idx, importance_loss

    def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
        """Dispatch to the routing implementation selected by ``route_method``.

        if first_layer: x [bz, num_tokens, hidden]
        else: x [bz*num_beams, num_tokens, hidden]

        ``attention_mask`` is accepted but not used, and ``use_log`` is
        ignored: the routing methods are always called with ``use_log=True``
        (both unchanged from the original).

        Returns:
            ``(candidate_output, beam_scores, expert_route, beam_idx,
            importance_loss)``.

        Raises:
            ValueError: if ``route_method`` is not supported.  The original
                ``assert("...")`` on a non-empty string could never fail, so
                this case used to end in an UnboundLocalError.
        """
        if self.route_method == 'pre-route':
            return self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
        if self.route_method in ['post-route', 'post-route-dp']:
            return self.forward_post_route(x, beam_scores, expert_route, use_log=True)
        raise ValueError(
            "route method should in pre-route, post-route, post-route-dp; "
            f"got {self.route_method!r}"
        )
|
||||
|
||||
|
||||
|
@ -22,6 +22,9 @@ class RouteMoELayer(nn.Module):
|
||||
gate = nn.Linear(hidden_size, 1, bias=False).float()
|
||||
self.gate = gate
|
||||
# self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
|
||||
elif self.route_method in ['cls-route', 'cls-query-route', 'cls-cross-route']:
|
||||
self.gate = nn.Linear(hidden_size, 1, bias=False).float()
|
||||
|
||||
|
||||
def _importance_auxiliary_loss(self, prob_gate):
|
||||
# From VMOE
|
||||
@ -33,7 +36,6 @@ class RouteMoELayer(nn.Module):
|
||||
# Compute coefficient of variation (i.e. std/mean) squared.
|
||||
return (std_importance_per_expert / mean_importance_per_expert)**2
|
||||
|
||||
|
||||
def forward_gate(self, x):
|
||||
"""
|
||||
x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
|
||||
@ -91,7 +93,7 @@ class RouteMoELayer(nn.Module):
|
||||
return beam_scores, expert_route, beam_idx
|
||||
|
||||
def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
|
||||
if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
|
||||
if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route', 'cls-route', 'cls-query-route', 'cls-cross-route']:
|
||||
# current_scores_log torch.Size([bz, num_experts])
|
||||
assert beam_scores==None and expert_route==None
|
||||
current_scores = torch.exp(current_scores_log)
|
||||
@ -247,7 +249,85 @@ class RouteMoELayer(nn.Module):
|
||||
|
||||
return final_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||
|
||||
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
|
||||
def calculate_cls_gate_score(self, cls_hidden, output_x):
    """Score one expert's output with the shared gate, conditioned on the CLS state.

    Args:
        cls_hidden: CLS hidden state, ``[bz, hidden]`` per the original notes.
        output_x: one expert's output, ``[bz, num_tokens, hidden]`` per the
            original notes.

    Returns:
        The gate logit for this expert, ``[bz, 1]``.

    Raises:
        ValueError: if ``self.route_method`` is not one of ``'cls-route'``,
            ``'cls-query-route'`` or ``'cls-cross-route'`` (this used to end
            in an UnboundLocalError at the return).
    """
    if self.route_method == 'cls-route':
        # Gate on the CLS state alone; output_x is not used.
        gate_input = cls_hidden  # bz, hidden
    elif self.route_method == 'cls-query-route':
        # Add the CLS state to the mean-pooled query-token output.
        gate_input = torch.mean(output_x, dim=1) + cls_hidden  # bz, hidden
    elif self.route_method == 'cls-cross-route':
        # Scaled dot-product attention with cls_hidden as Q and output_x as
        # both K and V.
        query = cls_hidden.unsqueeze(1)  # bz, 1, hidden
        attn = torch.matmul(query, output_x.transpose(-1, -2)) / (output_x.size(-1) ** 0.5)  # bz, 1, num_tokens
        attn = F.softmax(attn, dim=-1)
        gate_input = torch.matmul(attn, output_x).squeeze(1)  # bz, hidden
    else:
        raise ValueError(
            f"unsupported route_method for CLS gating: {self.route_method!r}"
        )
    return self.gate(gate_input)  # bz, 1
|
||||
|
||||
|
||||
def forward_cls_route(self, x, beam_scores, expert_route, cls_hidden):
    """Route with CLS-conditioned gate scores computed on every expert's output.

    Every expert processes ``x``; ``calculate_cls_gate_score`` turns each
    output (together with ``cls_hidden``) into one logit, the logits are
    soft-maxed across experts, and ``beam_search`` picks the routes.

    Args:
        x: ``[bz, num_tokens, hidden]`` on the first layer, otherwise
            ``[bz*num_beams, num_tokens, hidden]`` (per the original notes).
        beam_scores: running beam scores, or ``None`` on the first layer.
        expert_route: experts chosen so far, or ``None`` on the first layer.
        cls_hidden: CLS hidden state passed to ``calculate_cls_gate_score``.

    Returns:
        ``(final_output, beam_scores, expert_route, beam_idx,
        importance_loss)``.

    Raises:
        ValueError: if ``cls_hidden`` is ``None``.
    """
    # The gate cannot be evaluated without a CLS state; fail with a clear
    # message instead of an opaque error deep inside the gate.
    if cls_hidden is None:
        raise ValueError("cls_hidden is required for CLS-based routing")

    # All-ones mask: numerically a no-op (kept from the original).
    attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
    x_masked = x * attention_mask.unsqueeze(-1)  # [rows, num_tokens, hidden]

    outputs = []
    logits_gate_lst = []
    for expert_idx in range(self.num_experts):
        # Call the module rather than ``.forward`` so registered hooks run.
        output_x = self.experts[expert_idx](x_masked)  # [rows, num_tokens, hidden]
        gate_score = self.calculate_cls_gate_score(cls_hidden, output_x)  # [rows, 1]
        logits_gate_lst.append(gate_score)
        outputs.append(output_x)

    candidate_output_raw = torch.stack(outputs)  # [num_experts, rows, num_tokens, hidden]
    logits_gate = torch.cat(logits_gate_lst, dim=1)  # [rows, num_experts]
    current_scores = F.softmax(logits_gate, dim=-1)  # [rows, num_experts]

    # In log space the probabilities can simply be added.
    current_scores_log = torch.log(current_scores)

    # importance loss
    importance_loss = self._importance_auxiliary_loss(current_scores)

    batch_size, num_tokens = x.shape[0], x.shape[1]  # rows is bz or bz*num_beams

    beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)

    # beam_scores: [bz*num_beams]; expert_route: [bz*num_beams, layer_n]
    current_select_expert = expert_route[:, -1]  # [bz*num_beams]

    if self.layer_judge == 'first':
        # Replicate each sample once per beam.
        candidate_output_raw = (
            candidate_output_raw.unsqueeze(2)
            .expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
            .contiguous()
            .view(self.num_experts, -1, num_tokens, self.hidden_size)
        )
        current_scores = (
            current_scores.unsqueeze(1)
            .expand(batch_size, self.num_beams, self.num_experts)
            .contiguous()
            .view(-1, self.num_experts)
        )  # [bz*num_beams, num_experts]

    # [bz*num_beams, num_experts, num_tokens, hidden], gathered per parent beam
    candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx]
    expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
    if self.weight_type == 'ffn_prob':
        weights = current_scores[beam_idx] * expert_select_matrix
    else:
        weights = expert_select_matrix
    # One-hot weights zero out all but the selected expert before the sum.
    final_output = torch.sum(candidate_output * weights.unsqueeze(-1).unsqueeze(-1), dim=1)

    return final_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||
|
||||
|
||||
def forward(self, x, attention_mask, beam_scores, expert_route, cls_hidden=None):
|
||||
"""
|
||||
if first_layer: x [bz, 32, 768]
|
||||
else: x [bz*num_beams, 32, 768]
|
||||
@ -257,9 +337,8 @@ class RouteMoELayer(nn.Module):
|
||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
|
||||
elif self.route_method in ['post-route', 'post-route-dp']:
|
||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
|
||||
elif self.route_method in ['cls-route', 'cls-query-route', 'cls-cross-route']:
|
||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_cls_route(x, beam_scores, expert_route, cls_hidden)
|
||||
else:
|
||||
assert("route method should in pre-route, post-route, post-route-dp")
|
||||
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||
|
||||
|
||||
|
||||
|
@ -142,7 +142,32 @@ class UniRouteMoELayer(nn.Module):
|
||||
# import pdb;pdb.set_trace()
|
||||
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||
|
||||
def forward_post_route_uni(self, x, beam_scores, expert_route, use_log=True):
|
||||
|
||||
def calculate_cls_gate_score(self, cls_hidden, output_x):
    """Score one routed expert's output with the shared gate and the CLS state.

    Args:
        cls_hidden: CLS hidden state, ``[bz, hidden]`` per the original notes.
        output_x: one expert's output, ``[bz, num_tokens, hidden]`` per the
            original notes.

    Returns:
        The gate logit for this expert, ``[bz, 1]``.

    Raises:
        ValueError: if ``self.route_method`` is not one of
            ``'uni-cls-route'``, ``'uni-cls-query-route'`` or
            ``'uni-cls-cross-route'`` (this used to end in an
            UnboundLocalError at the return).
    """
    if self.route_method == 'uni-cls-route':
        # Gate on the CLS state alone; output_x is not used.
        gate_input = cls_hidden  # bz, hidden
    elif self.route_method == 'uni-cls-query-route':
        # Add the CLS state to the mean-pooled query-token output.
        gate_input = torch.mean(output_x, dim=1) + cls_hidden  # bz, hidden
    elif self.route_method == 'uni-cls-cross-route':
        # Scaled dot-product attention with cls_hidden as Q and output_x as
        # both K and V.
        query = cls_hidden.unsqueeze(1)  # bz, 1, hidden
        attn = torch.matmul(query, output_x.transpose(-1, -2)) / (output_x.size(-1) ** 0.5)  # bz, 1, num_tokens
        attn = F.softmax(attn, dim=-1)
        gate_input = torch.matmul(attn, output_x).squeeze(1)  # bz, hidden
    else:
        raise ValueError(
            f"unsupported route_method for CLS gating: {self.route_method!r}"
        )
    return self.gate(gate_input)  # bz, 1
|
||||
|
||||
|
||||
def forward_route_uni(self, x, beam_scores, expert_route, use_log=True, cls_hidden=None):
|
||||
|
||||
if beam_scores == None:
|
||||
batch_size = x.shape[0]
|
||||
@ -165,8 +190,14 @@ class UniRouteMoELayer(nn.Module):
|
||||
logits_gate_lst = list()
|
||||
for expert_idx in range(self.num_route_experts): # num_expert-1
|
||||
output_x = forward_expert(x_masked, expert_idx)
|
||||
output_x_aver = torch.mean(output_x, dim=1)
|
||||
gate_score = self.gate(output_x_aver)
|
||||
|
||||
if self.route_method == 'post-route-uni':
|
||||
output_x_aver = torch.mean(output_x, dim=1)
|
||||
gate_score = self.gate(output_x_aver)
|
||||
|
||||
elif self.route_method in ['uni-cls-route', 'uni-cls-query-route', 'uni-cls-cross-route'] and cls_hidden is not None:
|
||||
gate_score = self.calculate_cls_gate_score(cls_hidden, output_x)
|
||||
|
||||
logits_gate_lst.append(gate_score)
|
||||
outputs.append(output_x.unsqueeze(0))
|
||||
|
||||
@ -198,8 +229,6 @@ class UniRouteMoELayer(nn.Module):
|
||||
output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
|
||||
experts_output = torch.sum(output, dim=1) # [bz*num_beams-1, 32, 768]
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
####################
|
||||
### universal expert
|
||||
####################
|
||||
@ -219,15 +248,15 @@ class UniRouteMoELayer(nn.Module):
|
||||
|
||||
return final_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||
|
||||
def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
|
||||
def forward(self, x, attention_mask, beam_scores, expert_route, cls_hidden):
|
||||
"""
|
||||
if first_layer: x [bz, 32, 768]
|
||||
else: x [bz*num_beams, 32, 768]
|
||||
"""
|
||||
if self.route_method == 'pre-route-uni':
|
||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
|
||||
elif self.route_method in ['post-route-uni']:
|
||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route_uni(x, beam_scores, expert_route, use_log=True)
|
||||
elif self.route_method in ['post-route-uni', 'uni-cls-route', 'uni-cls-query-route', 'uni-cls-cross-route']:
|
||||
candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_route_uni(x, beam_scores, expert_route, use_log=True, cls_hidden=cls_hidden)
|
||||
|
||||
return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user