# MiniGPT-4/minigpt4/models/moe/beam_search.py
import copy
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    def __init__(self, hidden_size, expert, gate, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'):
        # remove hash list
        nn.Module.__init__(self)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
        self.route_method = route_method
        self.topk = topk
        self.use_balance_loss = use_balance_loss
        self.weight_type = weight_type
        if route_method in ["gate-token", "gate-sentence"]:
            self.gate = gate
        else:
            raise KeyError("Routing method not supported.")

    def _forward_gate_sentence(self, x, attention_mask):
        """
        x: query_attention_output, torch.Size([bz, 32, 768])
        attention_mask: torch.ones([bz, 32])

        Notice:
        The raw version of expert_attention_mask is the extended_attention_mask,
        which is added to the attention scores directly; its values are -0.0 or
        -10000, so it must be converted to a 1/0 mask before the experts can use it.
        """
        attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)  # torch.Size([bz, 32, 768])
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz, 768])
        logits_gate = self.gate(x_average)  # torch.Size([bz, num_experts])
        prob_gate = F.softmax(logits_gate, dim=-1)  # torch.Size([bz, num_experts])
        select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1)  # gate: the experts assigned to each sample, torch.Size([bz, topk])

        # weight the expert outputs (l2-norm variant)
        if self.weight_type == 'l2_norm':
            # normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0)  # L2 normalization, torch.Size([bz, topk])
            normalized_tensor = select_prob_gate

        num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0)  # number of samples assigned to each expert, torch.Size([num_experts])
        gate_load = num_sentences.clone()

        # forward experts
        def forward_expert(input_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            return input_x

        result_lst = list()
        for i in range(self.topk):
            # For each rank (top1, top2, ...): group the samples by their gate choice,
            # run each group through its expert, then sum the (probability-weighted) results.
            tmp_gate = gate[:, i]
            tmp_prob = normalized_tensor[:, i].unsqueeze(-1).unsqueeze(-1)
            order = tmp_gate.argsort(0)
            num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0)
            x1 = x[order]  # reorder according to expert number
            x1 = x1.split(num_sentences_t.tolist(), dim=0)  # a list of length self.num_experts

            result = []
            for expert_idx in range(self.num_experts):  # renamed from `i`, which shadowed the topk loop variable
                if x1[expert_idx].size(0) > 0:
                    result.append(forward_expert(x1[expert_idx], expert_idx))
            result = torch.vstack(result)
            result = result[order.argsort(0)]  # restore original order
            # result_lst.append(result * tmp_prob)  # result * prob
            result_lst.append(result)  # result * prob

        moe_result = sum(result_lst)
        print('Layer Qformer MoE: \n', prob_gate)
        return moe_result, select_prob_gate, gate
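
    # Worked example (added for illustration) of the sort/split/restore pattern
    # above, assuming 4 samples and gate assignments tmp_gate = [2, 0, 2, 1]:
    #   order         = [1, 3, 0, 2]  (rows grouped by expert: e0, e1, e2, e2)
    #   split sizes   = [1, 1, 2]     (samples per expert)
    #   order.argsort = [2, 0, 3, 1]  (scatters expert outputs back to input order)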

    def forward(self, x, attention_mask):
        # NOTE: only _forward_gate_sentence is defined in this file; the other
        # routing branches call methods provided elsewhere in the codebase.
        if self.route_method == "gate-token":
            x, balance_loss, gate_load = self._forward_gate_token(x)
        elif self.route_method == "gate-sentence":
            if x.size(0) == 1:
                x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
            else:
                x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
        elif self.route_method == "gate-sentence-post":
            x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask)
        else:
            raise KeyError("Routing method not supported.")
        return x, balance_loss, gate_load


class RouteMoELayer(nn.Module):
    def __init__(self, hidden_size, expert, gate, num_experts, num_beams=2, layer_judge=None, route_method="pre-route"):
        # remove hash list
        nn.Module.__init__(self)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
        self.num_beams = num_beams
        self.hidden_size = hidden_size
        self.layer_judge = layer_judge
        self.route_method = route_method
        if self.route_method == "pre-route":
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif self.route_method == "post-route":
            # gate = nn.Linear(hidden_size, 1, bias=False).float()
            self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
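        # Design note (added): "pre-route" scores experts from the *input* with a
        # single Linear(hidden_size, num_experts), while "post-route" scores each
        # expert's *output* with its own scalar gate (one deep copy per expert of
        # the passed-in Linear(hidden_size, 1)).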

    def forward_gate(self, x):
        """
        x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
        prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts])
        """
        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)  # torch.Size([bz*num_beams, 32, 768])
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz*num_beams, 768])
        logits_gate = self.gate(x_average)  # torch.Size([bz*num_beams, num_experts])
        prob_gate = F.softmax(logits_gate, dim=-1)  # torch.Size([bz*num_beams, num_experts])
        return prob_gate
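
    # Note (added): because the mask above is all ones, the masked average is
    # equivalent to x.mean(dim=1); the mask plumbing is kept so that a real 1/0
    # padding mask could be substituted without changing the rest of the code.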

    def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
        if self.layer_judge == 'first' and self.route_method == 'pre-route':
            assert beam_scores is None and expert_route is None
            current_scores = torch.exp(current_scores_log)
            topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1)  # gate: the experts assigned to each sample, torch.Size([bz, num_beams])
            beam_scores = topk_values.view(self.num_beams * batch_size)  # torch.Size([bz * num_beams])
            expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1)  # torch.Size([bz * num_beams, 1])
            beam_idx = None
        else:
            if self.layer_judge == 'first' and self.route_method == 'post-route':
                next_scores_raw1 = torch.exp(current_scores_log)  # torch.Size([bz, num_experts])
            else:
                batch_size = int(batch_size // self.num_beams)
                next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1)  # torch.Size([bz*num_beams, num_experts]); in log space the probabilities can simply be added
                next_scores_exp = torch.exp(next_scores_raw)
                next_scores_raw1 = next_scores_exp.view(
                    batch_size, self.num_beams * self.num_experts
                )  # torch.Size([bz, num_beams*num_experts])

            next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True)
            # next_scores  torch.Size([bz, num_beams])
            # next_experts torch.Size([bz, num_beams])

            next_batch_beam = list()
            for batch_idx in range(batch_size):
                next_sent_beam = list()
                for rank, (expert_id, expert_score) in enumerate(
                    zip(next_experts[batch_idx], next_scores[batch_idx])
                ):
                    expert_id = expert_id.item()
                    beam_id = expert_id // self.num_experts
                    ex_id = expert_id % self.num_experts
                    effective_beam_id = batch_idx * self.num_beams + beam_id
                    next_sent_beam.append((expert_score, ex_id, effective_beam_id))
                next_batch_beam.extend(next_sent_beam)

            if self.layer_judge == 'first' and self.route_method == 'post-route':
                beam_scores = next_scores.view(self.num_beams * batch_size)  # torch.Size([bz * num_beams])
                expert_route = next_experts.view(self.num_beams * batch_size)
                beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
                beam_experts = expert_route.new([x[1] for x in next_batch_beam]).unsqueeze(-1)
                beam_idx = expert_route.new([int(x[2] / self.num_beams) for x in next_batch_beam])
                expert_route = beam_experts
            else:
                beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
                beam_experts = expert_route[:, -1].new([x[1] for x in next_batch_beam])
                beam_idx = expert_route[:, -1].new([x[2] for x in next_batch_beam])
                pre_route = expert_route[beam_idx, :]
                expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)

        return beam_scores, expert_route, beam_idx
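
    # Worked example (added for illustration): with num_beams=2 and num_experts=3,
    # the flattened score vector per sample has 2*3 = 6 entries. A top-k index of
    # 4 decomposes as beam_id = 4 // 3 = 1 and ex_id = 4 % 3 = 1, i.e. "extend
    # parent beam 1 with expert 1"; effective_beam_id then locates that parent
    # beam in the [bz*num_beams] flattening.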

    def forward_expert_ffn(self, x, expert_select, beam_scores):
        """
        x : [bz*num_beams, 32, 768]
        expert_select : [bz*num_beams]
        """
        # add_1212 l2_normalization
        # normalized_tensor = torch.nn.functional.normalize(beam_scores, p=2, dim=0)  # L2 normalization, torch.Size([bz, topk])
        # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1)
        outputs = list()
        for i in range(x.shape[0]):
            output_x = self.experts[expert_select[i]].forward(x[i])
            outputs.append(output_x.unsqueeze(0))
        candidate_output = torch.cat(outputs)
        # candidate_output = candidate_output * tmp_prob
        return candidate_output  # torch.Size([bz*num_beams, 32, 768])
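
    # Design note (added): this dispatch runs one beam row at a time; the grouped
    # sort/split pattern used in MoELayer._forward_gate_sentence could batch rows
    # that share an expert, at the cost of the extra reordering bookkeeping.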

    def forward_pre_route(self, x, beam_scores, expert_route, use_log=True):
        current_scores = self.forward_gate(x)  # [bz*num_beams, num_experts]
        if use_log:
            current_scores_log = torch.log(current_scores)  # in log space, probabilities can simply be added
        else:
            current_scores_log = current_scores

        batch_size, num_tokens = x.shape[0], x.shape[1]
        beam_scores, expert_route, _ = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
        current_expert_select = expert_route[:, -1]

        if self.layer_judge == 'first':  # expand the first dim to batch_size * num_beams
            replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
            x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size)  # [bz*num_beams, 32, 768]

        candidate_output = self.forward_expert_ffn(x, current_expert_select, beam_scores)  # [bz*num_beams, 32, 768]
        return candidate_output, beam_scores, expert_route

    def forward_post_route(self, x, beam_scores, expert_route, use_log=True):
        # if self.layer_judge == 'first':  # expand the first dim to batch_size * num_beams
        #     batch_size, num_tokens = x.shape[0], x.shape[1]
        #     replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size)
        #     x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size)  # [bz*num_beams, 32, 768]
        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
        x_masked = x * attention_mask.unsqueeze(-1)  # torch.Size([bz, 32, 768])

        def forward_expert(input_x, expert_idx):
            output_x = self.experts[expert_idx].forward(input_x)
            return output_x

        outputs = list()
        logits_gate_lst = list()
        for expert_idx in range(self.num_experts):
            output_x = forward_expert(x_masked, expert_idx)
            outputs.append(output_x.unsqueeze(0))
            output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1)  # torch.Size([bz*num_beams, 768])
            gate_score = self.gates[expert_idx](output_x_aver)  # fixed typo: was `gate_acore`
            logits_gate_lst.append(gate_score)

        candidate_output = torch.cat(outputs)  # torch.Size([num_experts, bz*num_beams, 32, 768])
        logits_gate = torch.cat(logits_gate_lst, dim=1)  # torch.Size([bz*num_beams, num_experts])
        current_scores = F.softmax(logits_gate, dim=-1)  # torch.Size([bz*num_beams, num_experts])

        if use_log:
            current_scores_log = torch.log(current_scores)  # in log space, probabilities can simply be added
        else:
            current_scores_log = current_scores

        batch_size = x.shape[0]  # bz*num_beams
        beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
        # beam_scores  torch.Size([bz*num_beams])
        # expert_route torch.Size([bz*num_beams, layer_n])
        current_select_expert = expert_route[:, -1]

        output = list()
        for i in range(beam_idx.shape[0]):
            b_idx = beam_idx[i]
            ex_idx = current_select_expert[i]
            ex_out = candidate_output[ex_idx, b_idx, :, :]
            output.append(ex_out.unsqueeze(0))
        final_output = torch.concat(output, dim=0)

        return final_output, beam_scores, expert_route, beam_idx
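
    # Illustration (added): candidate_output stacks every expert's output for every
    # beam row, so candidate_output[ex_idx, b_idx] reads as "expert ex_idx applied
    # to parent beam b_idx"; the gather above keeps exactly one such slice per
    # surviving beam.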

    def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
        """
        if first layer: x [bz, 32, 768]
        else:           x [bz*num_beams, 32, 768]
        """
        if self.route_method == 'pre-route':
            # forward_pre_route returns three values; beam_idx is only produced by post-route
            candidate_output, beam_scores, expert_route = self.forward_pre_route(x, beam_scores, expert_route, use_log=use_log)
            beam_idx = None
        elif self.route_method == "post-route":
            candidate_output, beam_scores, expert_route, beam_idx = self.forward_post_route(x, beam_scores, expert_route, use_log=use_log)
        else:
            raise KeyError("Routing method not supported.")
        return candidate_output, beam_scores, expert_route, beam_idx


if __name__ == '__main__':
    import sys
    sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE")
    from minigpt4.models.QformerRouteMoE import BertConfig
    from minigpt4.models.QformerRouteMoE import FeedForward
    from minigpt4.models.moe.utils import (
        use_experts,
        moe_layer_judge,
    )

    vision_width = 1408
    cross_attention_freq = 2
    num_query_token = 32

    # init QformerMoE
    config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased")
    config.encoder_width = vision_width
    # insert cross-attention layer every other block
    config.add_cross_attention = True
    config.cross_attention_freq = cross_attention_freq
    config.query_length = num_query_token
    config.moebert_expert_num = 3
    config.moebert_num_beams = 3
    config.moebert_route_method = 'gate-sentence'
    config.moe_topk = 2
    config.use_balance_loss = False
    config.moe_weight_type = 'l2_norm'

    batch_size = 4
    x = torch.randn(batch_size, 32, 768)
    beam_scores, expert_route = None, None
    x1 = x
    x2 = x
    beam_scores1, expert_route1 = None, None

    for layer_num in [6, 8, 10]:
        layer_judge = moe_layer_judge(layer_num)
        ffn = FeedForward(config)
        gate = nn.Linear(768, config.moebert_expert_num, bias=False).float()

        # experts = RouteMoELayer(
        #     hidden_size=768,
        #     expert=ffn,
        #     gate=gate,
        #     num_experts=config.moebert_expert_num,
        #     num_beams=config.moebert_num_beams,
        #     layer_judge=layer_judge,
        #     route_method="pre-route"
        # )
        # layer_output = experts(x, None, beam_scores, expert_route)
        # hidden_states1, beam_scores, expert_route, _ = layer_output
        # print(beam_scores)
        # print(expert_route)

        gate1 = nn.Linear(768, 1, bias=False).float()
        experts_post = RouteMoELayer(
            hidden_size=768,
            expert=ffn,
            gate=gate1,
            num_experts=config.moebert_expert_num,
            num_beams=config.moebert_num_beams,
            layer_judge=layer_judge,
            route_method="post-route"
        )
        layer_output = experts_post(x1, None, beam_scores1, expert_route1, False)
        hidden_states2, beam_scores1, expert_route1, beam_idx = layer_output
        print(beam_scores1)
        print(expert_route1)
        print(beam_idx)

        # experts_moe = MoELayer(
        #     hidden_size=config.hidden_size,
        #     expert=ffn,
        #     gate=gate,
        #     num_experts=config.moebert_expert_num,
        #     route_method=config.moebert_route_method,
        #     topk=config.moe_topk,
        #     use_balance_loss=config.use_balance_loss,
        #     weight_type=config.moe_weight_type,
        # )
        # attn_mask = torch.ones([batch_size, 32])
        # layer_output = experts_moe(x2, attn_mask)
        # hidden_states3, select_prob_gate, gate_load, _ = layer_output
        # print(select_prob_gate)
        # print(gate_load)

        # x = hidden_states1
        x1 = hidden_states2
        # x2 = hidden_states3
        print("------------------------------------")