# MiniGPT-4/minigpt4/models/moe/utils.py
import numpy as np
import pickle
import torch
import torch.nn as nn
from dataclasses import dataclass
from torch import Tensor
from transformers.activations import ACT2FN
from transformers.file_utils import ModelOutput
from typing import Optional, Tuple


def use_experts(layer_idx):
    """Return True if this layer replaces its dense FFN with MoE experts."""
    # The MoE FFN is applied after the cross-attention block, in the last
    # six encoder layers only.
    return int(layer_idx) in [6, 7, 8, 9, 10, 11]


def use_experts_route(layer_idx):
    """Return True if the routing gate is computed at this layer.

    Kept in sync with use_experts; an earlier scheme used layers
    [0, 2, 4, 6, 8, 10].
    """
    return int(layer_idx) in [6, 7, 8, 9, 10, 11]


def moe_layer_judge(layer_idx):
    """Classify an MoE layer's position within the expert stack.

    Returns 'first' for the first MoE layer, 'mid' for the intermediate
    ones, 'last' for the final one, and None for dense layers. (An earlier
    scheme used 0 / [2, 4, 6, 8] / 10.)
    """
    if layer_idx == 6:
        return 'first'
    elif layer_idx in [7, 8, 9, 10]:
        return 'mid'
    elif layer_idx == 11:
        return 'last'
    else:
        return None
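

# Layer schedule at a glance (an illustrative note, assuming the 12-layer
# BERT-style encoder this module targets):
#   [moe_layer_judge(i) for i in range(12)]
#   -> [None]*6 + ['first', 'mid', 'mid', 'mid', 'mid', 'last']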


def process_ffn(model):
    """Replace the dense FFN of every MoE-enabled layer with its experts."""
    if model.config.model_type == "bert":
        inner_model = model.bert
    else:
        raise ValueError("Model type not recognized.")

    for i in range(model.config.num_hidden_layers):
        model_layer = inner_model.encoder.layer[i]
        if model_layer.use_experts:
            model_layer.importance_processor.load_experts(model_layer)


class ImportanceProcessor:
    """Splits importance-ranked FFN neurons across experts (MoEBERT-style)."""

    def __init__(self, config, layer_idx, num_local_experts, local_group_rank):
        self.num_experts = config.moebert_expert_num  # total number of experts
        self.num_local_experts = num_local_experts  # number of experts on this device
        self.local_group_rank = local_group_rank  # rank in the current process group
        self.intermediate_size = config.moebert_expert_dim  # FFN hidden dimension of each expert
        self.share_importance = config.moebert_share_importance  # number of shared FFN dimensions

        importance = ImportanceProcessor.load_importance_single(config.moebert_load_importance)[layer_idx, :]
        self.importance = self._split_importance(importance)

        self.is_moe = False  # safety check: flipped to True once load_experts runs

    @staticmethod
    def load_importance_single(importance_files):
        """Load pickled importance indices; data["idx"] is indexed above as
        [layer_idx, :], i.e. one row of neuron indices per layer."""
        with open(importance_files, "rb") as file:
            data = pickle.load(file)
        data = data["idx"]
        return np.array(data)

    def _split_importance(self, arr):
        """Assign importance-ranked neuron indices to this device's experts."""
        result = []
        # The top-ranked neurons are shared by every expert.
        top_importance = arr[:self.share_importance]
        # The remaining neurons are dealt round-robin across all experts.
        remain = arr[self.share_importance:]
        all_experts_remain = []
        for i in range(self.num_experts):
            all_experts_remain.append(remain[i::self.num_experts])
        all_experts_remain = np.array(all_experts_remain)

        # Keep only the shards belonging to this device's local experts.
        for i in range(self.num_local_experts):
            temp = all_experts_remain[self.num_local_experts * self.local_group_rank + i]
            temp = np.concatenate((top_importance, temp))
            temp = temp[:self.intermediate_size]
            result.append(temp.copy())
        result = np.array(result)
        return result
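
    # Worked example for _split_importance (hypothetical numbers, added for
    # clarity): with share_importance=2, num_experts=2, num_local_experts=2,
    # local_group_rank=0, intermediate_size=4 and arr = [0, 1, ..., 9]:
    #   shared top neurons: [0, 1]
    #   round-robin shards: expert 0 -> [2, 4, 6, 8], expert 1 -> [3, 5, 7, 9]
    #   final indices:      expert 0 -> [0, 1, 2, 4], expert 1 -> [0, 1, 3, 5]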

    def load_experts(self, model_layer):
        """Carve the layer's dense FFN into per-expert sub-networks."""
        expert_list = model_layer.experts.experts
        fc1_weight_data = model_layer.intermediate.dense.weight.data
        fc1_bias_data = model_layer.intermediate.dense.bias.data
        fc2_weight_data = model_layer.output.dense.weight.data
        fc2_bias_data = model_layer.output.dense.bias.data
        layernorm_weight_data = model_layer.output.LayerNorm.weight.data
        layernorm_bias_data = model_layer.output.LayerNorm.bias.data
        for i in range(self.num_local_experts):
            idx = self.importance[i]
            expert_list[i].fc1.weight.data = fc1_weight_data[idx, :].clone()
            expert_list[i].fc1.bias.data = fc1_bias_data[idx].clone()
            expert_list[i].fc2.weight.data = fc2_weight_data[:, idx].clone()
            expert_list[i].fc2.bias.data = fc2_bias_data.clone()
            expert_list[i].LayerNorm.weight.data = layernorm_weight_data.clone()
            expert_list[i].LayerNorm.bias.data = layernorm_bias_data.clone()
        # Free the dense FFN; the experts now own (copies of) its weights.
        del model_layer.intermediate
        del model_layer.output
        self.is_moe = True
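

# Shape intuition for ImportanceProcessor.load_experts (an explanatory note,
# not in the original): nn.Linear stores weight as (out_features, in_features):
#   intermediate.dense.weight : (ffn_dim, hidden) -> rows    = FFN neurons
#   output.dense.weight       : (hidden, ffn_dim) -> columns = FFN neurons
# Indexing rows of fc1 and columns of fc2 with the same idx therefore keeps
# a consistent sub-network of the most important neurons in each expert.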


class FeedForward(nn.Module):
    """A BERT-style FFN block: Linear -> activation -> Linear -> dropout,
    with a residual connection and LayerNorm on the output."""

    def __init__(self, config, intermediate_size, dropout):
        nn.Module.__init__(self)
        # first layer
        self.fc1 = nn.Linear(config.hidden_size, intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        # second layer
        self.fc2 = nn.Linear(intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: Tensor):
        input_tensor = hidden_states
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
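

# A minimal usage sketch for FeedForward (the config object here is a
# stand-in; only the attributes read above are assumed):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(hidden_size=768, hidden_act="gelu", layer_norm_eps=1e-12)
#   ffn = FeedForward(cfg, intermediate_size=3072, dropout=0.1)
#   out = ffn(torch.randn(2, 16, 768))  # residual + LayerNorm keep (2, 16, 768)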


@dataclass
class MoEModelOutput(ModelOutput):
    """Encoder output extended with MoE bookkeeping (gate loss and loads,
    per-beam routing scores, and the chosen expert route)."""

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    gate_loss: torch.FloatTensor = None
    gate_loads: Optional[Tuple[torch.FloatTensor]] = None
    beam_scores: Optional[Tuple[torch.FloatTensor]] = None
    expert_route: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MoEModelOutputWithPooling(ModelOutput):
    """Same as MoEModelOutput, plus the pooled sequence representation
    (pooler_output)."""

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    gate_loss: torch.FloatTensor = None
    gate_loads: Optional[Tuple[torch.FloatTensor]] = None
    beam_scores: Optional[Tuple[torch.FloatTensor]] = None
    expert_route: Optional[Tuple[torch.FloatTensor]] = None
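

if __name__ == "__main__":
    # Small smoke test (added for illustration; not part of the original
    # module): check the layer schedule and the MoE output container.
    assert all(not use_experts(i) for i in range(6))
    assert moe_layer_judge(6) == 'first' and moe_layer_judge(11) == 'last'

    out = MoEModelOutput(
        last_hidden_state=torch.randn(1, 4, 32),
        gate_loss=torch.tensor(0.0),
    )
    print(out.last_hidden_state.shape, out.gate_loss)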