diff --git a/minigpt4/models/QformerRouteMoECLS.py b/minigpt4/models/QformerRouteMoECLS.py new file mode 100644 index 0000000..1bfff87 --- /dev/null +++ b/minigpt4/models/QformerRouteMoECLS.py @@ -0,0 +1,1374 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts_route, + moe_layer_judge, +) +from minigpt4.models.moe.route_moe_layer import RouteMoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... 
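+# Illustrative sketch (not exercised in this file): the extra BertConfig fields this
+# module reads on top of the standard BERT ones. Field names are taken from their uses
+# below; the concrete values here are placeholders / assumptions, not project defaults.
+#
+#   config = BertConfig.from_pretrained("bert-base-uncased")
+#   config.add_cross_attention = True
+#   config.cross_attention_freq = 2        # insert cross-attention every N layers
+#   config.encoder_width = 1408            # width of the frozen visual encoder features
+#   config.query_length = 32               # number of learnable query tokens
+#   config.moebert_expert_num = 4          # experts per RouteMoELayer
+#   config.moebert_num_beams = 2           # beam width used for expert routing
+#   config.route_method = "..."            # routing strategy consumed by RouteMoELayer
+#   config.moe_weight_type = "..."         # expert weighting scheme consumed by RouteMoELayer
+#   model = BertMoECLSRouteLMHeadModel(config)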
+logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def 
get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. 
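+        # Shape walkthrough (using the example sizes noted in the comments above):
+        #   query_layer: [bz, num_heads, q_len, head_dim], e.g. [1, 12, 41, 64]
+        #   key_layer:   [bz, num_heads, k_len, head_dim], e.g. [1, 12, 257, 64] for cross-attention
+        #   scores = query_layer @ key_layer^T -> [bz, num_heads, q_len, k_len], e.g. [1, 12, 41, 257]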
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1 + self.dropout = 
nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        # LayerNorm and the residual connection stay here; for the MoE FFN they are applied after the expert output instead
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class FeedForward(nn.Module):
+    # FFN without LayerNorm: the residual + LayerNorm is applied by the caller (see BertLayer.feed_forward_query_moe)
+    def __init__(self, config):
+        super().__init__()
+        self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+        self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # self.dropout = nn.Dropout(0.2)  # alternative: raise the dropout ratio from 0.1 to 0.2
+
+    def forward(self, hidden_states: Tensor):
+        hidden_states = self.dense1(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.dense2(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config, layer_num):
+        super().__init__()
+        self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)
+        self.layer_num = layer_num
+        if (
+            self.config.add_cross_attention
+            and layer_num % self.config.cross_attention_freq == 0
+        ):
+            self.crossattention = BertAttention(
+                config, is_cross_attention=self.config.add_cross_attention
+            )
+            self.has_cross_attention = True
+        else:
+            self.has_cross_attention = False
+
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+        self.intermediate_query = BertIntermediate(config)
+        self.output_query = BertOutput(config)
+
+        # Add MoE FFN
+        self.use_experts = use_experts_route(layer_num)
+        self.layer_judge = moe_layer_judge(layer_num)
+        self.num_beams = config.moebert_num_beams
+        ffn = FeedForward(config)
+        if self.use_experts:
+            self.experts = RouteMoELayer(
+                hidden_size=config.hidden_size,
+                expert=ffn,
+                num_experts=config.moebert_expert_num,
+                num_beams=config.moebert_num_beams,
+                layer_judge=self.layer_judge,
+                route_method=config.route_method,
+                weight_type=config.moe_weight_type,
+            )
+        else:
+            self.experts = ffn
+        self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        query_length=0,
+        beam_scores=None,
+        expert_route=None,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = (
+            past_key_value[:2] if past_key_value is not None else None
+        )
+
+        # adjust hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to a common batch dimension
+        if self.num_beams > 1:
+            if hidden_states.shape[0] == attention_mask.shape[0] * self.num_beams:
+                # expand the attention masks from bz to bz*num_beams
+                attention_mask = self.adjust_attention_mask(attention_mask)
+                encoder_attention_mask = self.adjust_attention_mask(encoder_attention_mask)
+
+            if hidden_states.shape[0] * self.num_beams == attention_mask.shape[0]:
+                # shrink the attention mask from bz*num_beams back to bz
+                batch_size = attention_mask.shape[0]
+                attention_mask = attention_mask[[i for i in range(0, batch_size, self.num_beams)]]
+
+            if hidden_states.shape[0] == encoder_hidden_states.shape[0] * self.num_beams:
+                batch_size, visual_tokens, vision_dim = encoder_hidden_states.shape
+                tmp = encoder_hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, visual_tokens, vision_dim)
+                encoder_hidden_states = tmp.contiguous().view(batch_size * self.num_beams, visual_tokens, vision_dim)  # torch.Size([bz*num_beams, 257, 1408])
+
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:-1]
+
+        present_key_value = self_attention_outputs[-1]
+
+        if query_length > 0:
+            query_attention_output = attention_output[:, :query_length, :]
+
+            if self.has_cross_attention:
+
+                assert (
+                    encoder_hidden_states is not None
+                ), "encoder_hidden_states must be given for cross-attention layers"
+
+                cross_attention_outputs = self.crossattention(
+                    query_attention_output,
+                    attention_mask,
+                    head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions=output_attentions,
+                )
+                query_attention_output = cross_attention_outputs[0]
+                outputs = (
+                    outputs + cross_attention_outputs[1:-1]
+                )  # add cross attentions if we output attention weights
+
+            if attention_output.shape[1] > query_length:  # text input is present in the Q-Former
+                layer_output_text = apply_chunking_to_forward(
+                    self.feed_forward_chunk,
+                    self.chunk_size_feed_forward,
+                    self.seq_len_dim,
+                    attention_output[:, query_length:, :],
+                )
+                cls_hidden = layer_output_text[:, 0, :]  # CLS hidden state used for routing, [bz, hidden_size]
+
+            # add moe query ffn
+            moe_ffn_attention_input = query_attention_output[:, :query_length, :]  # [bz, query_length, 768]
+            moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length]  # [bz, query_length]
+            layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden)
+            # layer_output = (hidden_states, beam_scores, expert_route, beam_idx, importance_loss)
+
+            if attention_output.shape[1] > query_length:
+                if self.layer_judge == 'first' and self.num_beams > 1:
+                    # expand layer_output_text from bz to bz*num_beams
+                    layer_output_text = self.adjust_hidden_states_by_num_beams(layer_output_text)
+
+                if self.layer_judge == 'mid' and self.num_beams > 1:
+                    # layer_output_text: [bz*num_beams, len, hidden_size]; reorder it to follow the selected beams
+                    beam_idx = layer_output[3]
+                    layer_output_text = layer_output_text[beam_idx]
+
+                if self.layer_judge == 'last' and self.num_beams > 1:
+                    # select the top-1 beam for each sample;
+                    # layer_output & layer_output_text dim 0 shrink from bz*num_beams to bz
+                    layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text)
+
+                layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3], layer_output[4])
+
+        else:
+            layer_output = apply_chunking_to_forward(
+                self.feed_forward_chunk,
+                self.chunk_size_feed_forward,
+                self.seq_len_dim,
+                attention_output,
+            )
+            layer_output = (layer_output, None, None, None, 0.0)
+
+        outputs = (layer_output,) + outputs
+
+        outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def 
adjust_attention_mask(self, attention_mask): + batch_size = attention_mask.shape[0] + tmp = attention_mask.unsqueeze(1).expand(batch_size, self.num_beams, 1, 1, attention_mask.shape[3]) + attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len]) + return attention_mask + + def adjust_hidden_states_by_num_beams(self, hidden_states): + batch_size, text_length, hidden_size = hidden_states.shape + tmp_text = hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size) + hidden_states = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768] + return hidden_states + + def route_moe_last_layer_top1(self, layer_output, layer_output_text): + batch_size = layer_output[0].shape[0] + raw_batch_size = int(batch_size / self.num_beams) + hidden_states, beam_scores, expert_route, beam_idx = layer_output[0], layer_output[1], layer_output[2], layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + scores = beam_scores.view(raw_batch_size, self.num_beams) + _, gate = torch.topk(scores, 1, dim=1) + selects = [ (bz_idx * self.num_beams + gate[bz_idx].item()) for bz_idx in range(raw_batch_size)] + + layer_output_text = layer_output_text[selects] + hidden_states_new = hidden_states[selects] + beam_scores_new = beam_scores[selects] + expert_route_new = expert_route[selects] + + return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden): + if not self.use_experts: + hidden_states = self.experts(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + return layer_output, None, None, None, 0.0 + + hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts( + attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden + ) + if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1: + attention_output = self.adjust_hidden_states_by_num_beams(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + + return layer_output, beam_scores, expert_route, beam_idx, importance_loss + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + beam_scores=None + expert_route=None + importance_loss = 0 + for i in range(self.config.num_hidden_layers): + + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = 
all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length, beam_scores, expert_route + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + beam_scores, # None + expert_route, # None + ) + hidden_states = layer_outputs[0][0] + beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1] + expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2] + importance_loss += layer_outputs[0][4] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + beam_scores=beam_scores, + expert_route=expert_route, + gate_loss=importance_loss, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. 
+ """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. 
+ # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
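+        # get_extended_attention_mask() broadcasts a [bz, seq_len] padding mask to
+        # [bz, 1, 1, seq_len] and maps 1 -> 0.0 (attend) and 0 -> -10000.0 (ignore), so it
+        # can simply be added to the raw attention scores. For example, a padding mask of
+        # [1, 1, 0] for one sample becomes additive biases [0.0, 0.0, -10000.0], broadcast
+        # over all heads and query positions.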
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + beam_scores=encoder_outputs.beam_scores, + expert_route=encoder_outputs.expert_route, + gate_loss=encoder_outputs.gate_loss + ) + + +class BertMoECLSRouteLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + 
past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + # gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + 
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerRouteMoECLSLN.py b/minigpt4/models/QformerRouteMoECLSLN.py new file mode 100644 index 0000000..8292b96 --- /dev/null +++ b/minigpt4/models/QformerRouteMoECLSLN.py @@ -0,0 +1,1365 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts_route, + moe_layer_judge, +) +from minigpt4.models.moe.route_moe_layer import RouteMoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def 
__init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
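+        # In the Q-Former, encoder_hidden_states are the frozen image features
+        # ([bz, 257, encoder_width], e.g. 1408-dim ViT tokens); the key/value projections
+        # above map them from encoder_width down to the Q-Former hidden size (768).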
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1 + self.dropout = 
nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Move LayerNorm & ResNet out of FFN After MoEFFN + hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1 + return hidden_states + + +class FeedForward(nn.Module): + def __init__(self, config): + nn.Module.__init__(self) + # first layer + self.intermediate_query = BertIntermediate(config) + # second layer + self.output_query = BertOutput(config) + + def forward(self, hidden_states: Tensor): + input_tensor = hidden_states + intermediate_output = self.intermediate_query(hidden_states) + hidden_states = self.output_query(intermediate_output, input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts_route(layer_num) + self.layer_judge = moe_layer_judge(layer_num) + self.num_beams = config.moebert_num_beams + ffn = FeedForward(config) + + if self.use_experts: + self.experts = RouteMoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = self.layer_judge, + route_method=config.route_method, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + beam_scores=None, + expert_route=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # import pdb; pdb.set_trace() # 0107test + + # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same + if self.num_beams > 1: + if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams: + # attention_mask dimension to be bz*num_beams + attention_mask = self.adjust_attention_mask(attention_mask) + encoder_attention_mask = self.adjust_attention_mask(encoder_attention_mask) + + if hidden_states.shape[0]*self.num_beams == attention_mask.shape[0]: + # attention_mask dimension back to bz + batch_size = attention_mask.shape[0] + attention_mask = attention_mask[[ i for i in range(0, batch_size, self.num_beams)]] + + if hidden_states.shape[0] == encoder_hidden_states.shape[0]*self.num_beams: + batch_size, visual_tokens, vision_dim = encoder_hidden_states.shape + tmp = encoder_hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, visual_tokens, vision_dim ) + encoder_hidden_states = tmp.contiguous().view(batch_size* self.num_beams, visual_tokens, vision_dim) # torch.Size([bz, 257, 1408]) + + 
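+        # After the adjustments above, the masks and the vision features match the leading
+        # dimension of hidden_states (bz or bz*num_beams), so the attention calls below run
+        # unchanged whether or not expert beam search has expanded the batch.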
self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + cls_hidden = layer_output_text[0][:, 0, :] # [bz, hidden_size] + + # add moe query ffn + moe_ffn_attention_input = query_attention_output[:, :query_length, :] # [bz, query_length+seq_len, 768] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] # [bz, 1, 1, query_length+seq_len] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden) + # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) + # import pdb; pdb.set_trace() # 0107test + + if attention_output.shape[1] > query_length: + if self.layer_judge == 'first' and self.num_beams>1: + # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1: + # adjust the dimension of layer_output_text to bz*num_beams + layer_output_text = self.adjust_hidden_states_by_num_beams(layer_output_text) + + if self.layer_judge == 'mid' and self.num_beams > 1: + # layer_output_text [bz*num_beams, len, hidden_size] + beam_idx = layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + if self.layer_judge == 'last' and self.num_beams>1: + # select top1 for each sample among beams + # layer_output = (hidden_states, beam_scores, expert_route) + # layer_output & layer_output_text dimen_0 from bz*num_beams to bz + layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text) + + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4]) + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, None, None, None, 0.0) + + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def adjust_attention_mask(self, attention_mask): + batch_size = attention_mask.shape[0] + tmp = attention_mask.unsqueeze(1).expand(batch_size, self.num_beams, 1, 1, attention_mask.shape[3]) + attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len]) + return attention_mask + + def adjust_layer_output_text(self, layer_output_text): + batch_size, text_length, hidden_size = 
layer_output_text.shape + tmp_text = layer_output_text.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size) + layer_output_text = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768] + return layer_output_text + + def route_moe_last_layer_top1(self, layer_output, layer_output_text): + batch_size = layer_output[0].shape[0] + raw_batch_size = int(batch_size / self.num_beams) + hidden_states, beam_scores, expert_route, beam_idx = layer_output[0], layer_output[1], layer_output[2], layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + scores = beam_scores.view(raw_batch_size, self.num_beams) + _, gate = torch.topk(scores, 1, dim=1) + selects = [ (bz_idx * self.num_beams + gate[bz_idx].item()) for bz_idx in range(raw_batch_size)] + + layer_output_text = layer_output_text[selects] + hidden_states_new = hidden_states[selects] + beam_scores_new = beam_scores[selects] + expert_route_new = expert_route[selects] + + return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text + + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + # layer_output = self.LayerNorm(layer_output + attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden): + if not self.use_experts: + layer_output = self.experts(attention_output) + return layer_output, None, None, None, 0.0 + + layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts( + attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden + ) + + return layer_output, beam_scores, expert_route, beam_idx, importance_loss + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + beam_scores=None + expert_route=None + importance_loss = 0 + for i in range(self.config.num_hidden_layers): + + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length, beam_scores, expert_route + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + beam_scores, # None + expert_route, # None + ) + hidden_states = layer_outputs[0][0] + beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1] + expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2] + importance_loss += layer_outputs[0][4] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + beam_scores=beam_scores, + expert_route=expert_route, + gate_loss=importance_loss, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. 
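+            Positions to attend to are encoded as 0.0 and masked positions as -10000.0, so the
+            mask can be added directly to the raw attention scores before the softmax.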
+ """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
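+        # In decoder mode a causal mask is built over the text tokens (with a UniLM-style block
+        # for prepended query tokens when has_query is set); in encoder mode the padding mask is
+        # simply broadcast to [batch_size, 1, 1, seq_length].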
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + beam_scores=encoder_outputs.beam_scores, + expert_route=encoder_outputs.expert_route, + gate_loss=encoder_outputs.gate_loss + ) + + +class BertMoECLSRouteLMHeadModelLNIn(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + 
past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + # gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + 
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerRouteMoELNUni.py b/minigpt4/models/QformerRouteMoELNUni.py index 6fbbd06..2c1be9a 100644 --- a/minigpt4/models/QformerRouteMoELNUni.py +++ b/minigpt4/models/QformerRouteMoELNUni.py @@ -517,15 +517,6 @@ class BertLayer(nn.Module): outputs + cross_attention_outputs[1:-1] ) # add cross attentions if we output attention weights - # add moe query ffn - # query_attention_output size: [bz, query_length+seq_len, 768] - # attention_mask size: [bz, 1, 1, query_length+seq_len] - moe_ffn_attention_input = query_attention_output[:, :query_length, :] - moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] - layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route) - # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) - import pdb; pdb.set_trace() # 0107test - if attention_output.shape[1] > query_length: # have text input in Qformer layer_output_text = apply_chunking_to_forward( self.feed_forward_chunk, @@ -533,6 +524,19 @@ class BertLayer(nn.Module): self.seq_len_dim, attention_output[:, query_length:, :], ) + 
cls_hidden = layer_output_text[0][:, 0, :] # [bz, hidden_size] + + # add moe query ffn + # query_attention_output size: [bz, query_length+seq_len, 768] + # attention_mask size: [bz, 1, 1, query_length+seq_len] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden) + # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) + # import pdb; pdb.set_trace() # 0107test + + if attention_output.shape[1] > query_length: # have text input in Qformer + if self.layer_judge == 'first' and self.num_beams>1: # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1: # adjust the dimension of layer_output_text to bz*num_beams @@ -622,13 +626,13 @@ class BertLayer(nn.Module): layer_output = self.output(intermediate_output, attention_output) return layer_output - def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route): + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden): if not self.use_experts: layer_output = self.experts(attention_output) return layer_output, None, None, None, 0.0 layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts( - attention_output, expert_attention_mask, beam_scores, expert_route + attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden ) return layer_output, beam_scores, expert_route, beam_idx, importance_loss diff --git a/minigpt4/models/QformerRouteMoEUni.py b/minigpt4/models/QformerRouteMoEUni.py index cb29d41..a2ceae2 100644 --- a/minigpt4/models/QformerRouteMoEUni.py +++ b/minigpt4/models/QformerRouteMoEUni.py @@ -523,12 +523,22 @@ class BertLayer(nn.Module): outputs + cross_attention_outputs[1:-1] ) # add cross attentions if we output attention weights + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + cls_hidden = layer_output_text[0][:, 0, :] # [bz, hidden_size] + # add moe query ffn # query_attention_output size: [bz, query_length+seq_len, 768] # attention_mask size: [bz, 1, 1, query_length+seq_len] moe_ffn_attention_input = query_attention_output[:, :query_length, :] moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] - layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route) + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route, cls_hidden) # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) # import pdb; pdb.set_trace() # 0107test @@ -628,14 +638,14 @@ class BertLayer(nn.Module): layer_output = self.output(intermediate_output, attention_output) return layer_output - def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route): + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden): if not self.use_experts: hidden_states = self.experts(attention_output) layer_output = self.expert_ln(hidden_states + attention_output) return 
layer_output, None, None, None, 0.0 hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts( - attention_output, expert_attention_mask, beam_scores, expert_route + attention_output, expert_attention_mask, beam_scores, expert_route, cls_hidden ) if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1: attention_output = self.adjust_hidden_states_by_num_beams(attention_output) diff --git a/minigpt4/models/blip2.py b/minigpt4/models/blip2.py index 51aa1ee..b8c51b0 100644 --- a/minigpt4/models/blip2.py +++ b/minigpt4/models/blip2.py @@ -27,6 +27,8 @@ from minigpt4.models.QformerRouteMoE import BertMoERouteLMHeadModel from minigpt4.models.QformerRouteMoELN import BertMoERouteLMHeadModelLNIn from minigpt4.models.QformerRouteMoELNUni import BertMoERouteLMHeadModelLNInUniversal from minigpt4.models.QformerRouteMoEUni import BertMoERouteLMHeadModelUniversal +from minigpt4.models.QformerRouteMoECLS import BertMoECLSRouteLMHeadModel +from minigpt4.models.QformerRouteMoECLSLN import BertMoECLSRouteLMHeadModelLNIn from minigpt4.models.eva_vit import create_eva_vit_g from transformers import BertTokenizer from peft import ( @@ -99,6 +101,38 @@ class Blip2Base(BaseModel): return RouteMoEQformer, query_tokens + @classmethod + def init_RouteCLSMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"): + moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + + moe_encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + moe_encoder_config.add_cross_attention = True + moe_encoder_config.cross_attention_freq = cross_attention_freq + moe_encoder_config.query_length = num_query_token + + moe_encoder_config.moebert_expert_num = moebert_expert_num + moe_encoder_config.moebert_num_beams = moebert_num_beams + moe_encoder_config.route_method = route_method + moe_encoder_config.moe_weight_type = moe_weight_type + + if ln_position == "out": + RouteMoEQformer = BertMoECLSRouteLMHeadModel.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + elif ln_position == "in": + # need to adjust + RouteMoEQformer = BertMoECLSRouteLMHeadModelLNIn.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, moe_encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=moe_encoder_config.initializer_range) + + return RouteMoEQformer, query_tokens + + @classmethod def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"): moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") diff --git a/minigpt4/models/blip2_vicuna_instruct.py b/minigpt4/models/blip2_vicuna_instruct.py index a2befec..318751a 100644 --- a/minigpt4/models/blip2_vicuna_instruct.py +++ b/minigpt4/models/blip2_vicuna_instruct.py @@ -85,7 +85,7 @@ class Blip2VicunaInstruct(Blip2Base): ) print('Initing & Loading Qformer') - if general_version in ['naive_moe', 'route_moe', 'uni_route_moe']: + if general_version in ['naive_moe', 'route_moe', 'uni_route_moe', 'cls_route_moe']: if general_version == 'naive_moe': self.Qformer, self.query_tokens = self.init_QformerMoE( 
num_query_token=num_query_token, @@ -121,6 +121,17 @@ class Blip2VicunaInstruct(Blip2Base): cross_attention_freq=2, ln_position=ln_position, ) + elif general_version == 'cls_route_moe': + self.Qformer, self.query_tokens = self.init_RouteCLSMoEQformer( + num_query_token=num_query_token, + vision_width=self.visual_encoder.num_features, + moebert_expert_num=moebert_expert_num, + moebert_num_beams=moebert_num_beams, + route_method=moebert_route_method, + moe_weight_type=moe_weight_type, + cross_attention_freq=2, + ln_position=ln_position, + ) elif general_version == 'base': self.Qformer, self.query_tokens = self.init_Qformer( diff --git a/minigpt4/models/moe/moe_layer_backup.py b/minigpt4/models/moe/backup/moe_layer_backup.py similarity index 100% rename from minigpt4/models/moe/moe_layer_backup.py rename to minigpt4/models/moe/backup/moe_layer_backup.py diff --git a/minigpt4/models/moe/backup/route_moe_layer_backup.py b/minigpt4/models/moe/backup/route_moe_layer_backup.py new file mode 100644 index 0000000..39ecb18 --- /dev/null +++ b/minigpt4/models/moe/backup/route_moe_layer_backup.py @@ -0,0 +1,265 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + +class RouteMoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.num_beams = num_beams + self.hidden_size = hidden_size + self.layer_judge = layer_judge + self.weight_type = weight_type + + self.route_method = route_method + if self.route_method == "pre-route": + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif self.route_method in ["post-route", "post-route-dp"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + self.gate = gate + # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. 
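+        # A smaller value means the gate spreads its total probability more evenly over the
+        # experts, so this term can be added to the training loss to discourage expert collapse.
+        # e.g. prob_gate = [[0.7, 0.3], [0.6, 0.4]] -> importance = [1.3, 0.7], and with
+        # torch.std (unbiased) the loss is (0.424 / 1.0)**2 ~= 0.18; a balanced gate gives 0.0.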
+        return (std_importance_per_expert / mean_importance_per_expert)**2
+
+
+    def forward_gate(self, x):
+        """
+            x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
+            prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts])
+        """
+        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
+        x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768])
+        # x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768])
+        x_average = torch.mean(x_masked, dim=1) # torch.Size([bz*num_beams, 768])
+        logits_gate = self.gate(x_average) # torch.Size([bz*num_beams, num_experts])
+        prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts])
+        return prob_gate
+
+    def dp_search(self, current_scores_log, beam_scores, expert_route, batch_size):
+        if self.layer_judge=='first' and self.route_method in ['post-route-dp']:
+            # current_scores_log torch.Size([bz, num_experts])
+            assert beam_scores==None and expert_route==None
+            current_scores = torch.exp(current_scores_log)
+            topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
+            beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams])
+            expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1])
+            beam_idx = torch.tensor(range(self.num_beams * batch_size))
+
+        else:
+            batch_size = int(batch_size // self.num_beams)
+            next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # after taking logs, probabilities can be accumulated by addition
+            next_scores_exp = torch.exp(next_scores_raw)
+
+            next_scores_raw, next_experts_raw = torch.topk(next_scores_exp, 1, dim=1, largest=True, sorted=True)
+            next_scores = next_scores_raw.view(batch_size, self.num_beams)
+            next_experts = next_experts_raw.view(batch_size, self.num_beams)
+            # next_scores, next_experts = torch.topk(current_scores_log, 1, dim=1, largest=True, sorted=True) # equivalent
+            # next_scores torch.Size([bz * num_beams, 1])
+            # next_tokens torch.Size([bz * num_beams, 1])
+
+            next_batch_beam = list()
+            for batch_idx in range(batch_size):
+                next_sent_beam = list()
+                expert_id = next_experts[batch_idx]
+                expert_score = next_scores[batch_idx]
+                values, index = torch.topk(expert_score, self.num_beams, dim=0, largest=True, sorted=True)
+                for i in range(self.num_beams):
+                    beam_id = index[i].item()
+                    ex_id = expert_id[beam_id].item()
+                    effective_beam_id = batch_idx*self.num_beams + beam_id
+                    next_sent_beam.append((values[i], ex_id, effective_beam_id))
+                next_batch_beam.extend(next_sent_beam)
+
+            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+            beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam])
+            beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam])
+            pre_route = expert_route[beam_idx,:]
+            expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1)
+
+        return beam_scores, expert_route, beam_idx
+
+    def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
+        if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
+            # current_scores_log torch.Size([bz, num_experts])
+            assert beam_scores==None and expert_route==None
+            current_scores = torch.exp(current_scores_log)
+            topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate: the expert assigned to each sample, torch.Size([bz, topk])
+            beam_scores = topk_values.view(self.num_beams * batch_size) #
torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_beams * batch_size)) + + else: + batch_size = int(batch_size // self.num_beams) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + + next_scores_raw1 = next_scores_exp.view( + batch_size, self.num_beams * self.num_experts + ) # torch.Size([bz, num_beams*num_experts]) + + next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True) + # next_scores torch.Size([bz, num_beams]) + # next_tokens torch.Size([bz, num_beams]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + for rank, (expert_id, expert_score) in enumerate( + zip(next_experts[batch_idx], next_scores[batch_idx]) + ): + expert_id = expert_id.item() + beam_id = expert_id // self.num_experts + ex_id = expert_id % self.num_experts + effective_beam_id = batch_idx*self.num_beams + beam_id + + next_sent_beam.append((expert_score, ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + return beam_scores, expert_route, beam_idx + + def forward_expert_ffn(self, x, expert_select, current_scores): + """ + x_repeat : [bz*num_beams, 32,768] + expert_select : [bz*num_beams] + current_scores : [bz*num_beams, num_experts] / [bz, num_experts] + """ + # add_1228 l2_normalization + # normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk]) + # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1) + # import pdb;pdb.set_trace() + outputs = list() + for i in range(self.num_experts): + output_x = self.experts[i].forward(x) + outputs.append(output_x.unsqueeze(1)) + candidate_output = torch.cat(outputs, dim=1) + expert_select_matrix = F.one_hot(expert_select, self.num_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores * expert_select_matrix + candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + output = torch.sum(candidate_output, dim=1) + # import pdb;pdb.set_trace() + return output # torch.Size([bz*num_beams, 32, 768]) + + def forward_pre_route(self, x, beam_scores, expert_route, use_log=True): + + current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams] + + importance_loss = self._importance_auxiliary_loss(current_scores) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + # import pdb;pdb.set_trace() + batch_size, num_tokens = x.shape[0], x.shape[1] + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + current_expert_select = expert_route[:,-1] + + if self.layer_judge=='first': # expand first dim to batch_size * num_beams + replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size) + x = replicated_tensor.contiguous().view(-1, 
num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + input_x = x[beam_idx] + candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768] + # import pdb;pdb.set_trace() + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward_post_route(self, x, beam_scores, expert_route, use_log=True): + + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + def forward_expert(input_x, expert_idx): + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = forward_expert(x_masked, expert_idx) + # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768]) + output_x_aver = torch.mean(output_x, dim=1) + # gate_score = self.gates[expert_idx](output_x_aver) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + outputs.append(output_x.unsqueeze(0)) + + candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert]) + current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts]) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + + # importance loss + importance_loss = self._importance_auxiliary_loss(current_scores) + + batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam + + if self.route_method == 'post-route': + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + elif self.route_method == 'post-route-dp': + beam_scores, expert_route, beam_idx = self.dp_search(current_scores_log, beam_scores, expert_route, batch_size) + + # beam_scores torch.Size([bz*num_beam]) + # expert_route torch.Size([bz*num_beam, layer_n]) + current_select_expert = expert_route[:,-1] + # current_select_expert torch.Size([bz*num_beam, 1]) + + if self.layer_judge == 'first': + replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size) + candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768]) + expert_select_matrix = F.one_hot(current_select_expert, self.num_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores[beam_idx] * expert_select_matrix + output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + final_output = torch.sum(output, dim=1) + + return final_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True): + 
""" + if first_layer: x [bz, 32, 768] + else: x [bz*num_beams, 32, 768] + + """ + if self.route_method == 'pre-route': + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True) + elif self.route_method in ['post-route', 'post-route-dp']: + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True) + else: + assert("route method should in pre-route, post-route, post-route-dp") + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + + diff --git a/minigpt4/models/moe/test_moe_layer.py b/minigpt4/models/moe/backup/test_moe_layer.py similarity index 100% rename from minigpt4/models/moe/test_moe_layer.py rename to minigpt4/models/moe/backup/test_moe_layer.py diff --git a/minigpt4/models/moe/route_moe_layer.py b/minigpt4/models/moe/route_moe_layer.py index 39ecb18..51475a8 100644 --- a/minigpt4/models/moe/route_moe_layer.py +++ b/minigpt4/models/moe/route_moe_layer.py @@ -22,6 +22,9 @@ class RouteMoELayer(nn.Module): gate = nn.Linear(hidden_size, 1, bias=False).float() self.gate = gate # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) + elif self.route_method in ['cls-route', 'cls-query-route', 'cls-cross-route']: + self.gate = nn.Linear(hidden_size, 1, bias=False).float() + def _importance_auxiliary_loss(self, prob_gate): # From VMOE @@ -33,7 +36,6 @@ class RouteMoELayer(nn.Module): # Compute coefficient of variation (i.e. std/mean) squared. return (std_importance_per_expert / mean_importance_per_expert)**2 - def forward_gate(self, x): """ x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768]) @@ -91,7 +93,7 @@ class RouteMoELayer(nn.Module): return beam_scores, expert_route, beam_idx def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size): - if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']: + if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route', 'cls-route', 'cls-query-route', 'cls-cross-route']: # current_scores_log torch.Size([bz, num_experts]) assert beam_scores==None and expert_route==None current_scores = torch.exp(current_scores_log) @@ -247,7 +249,85 @@ class RouteMoELayer(nn.Module): return final_output, beam_scores, expert_route, beam_idx, importance_loss - def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True): + def calculate_cls_gate_score(self, cls_hidden, output_x): + + if self.route_method == 'cls-route': + # cls_hidden = [bz, 768] + gate_score = self.gate(cls_hidden) # bz, 1 + elif self.route_method == 'cls-query-route': # add cls_hiddin on query_token mean pool hidden + mean_output = torch.mean(output_x, dim=1) # bz, 768 + gate_score = self.gate(mean_output+cls_hidden) # bz, 1 + elif self.route_method == 'cls-cross-route': + # cls_hidden as Q, output_x as K, V calculate scaled dot-product attention between Q and K and V + # cls_hidden: bz, 768 + # output_x: bz, 32, 768 + Q = cls_hidden.unsqueeze(1) # bz, 1, 768 + K = output_x # bz, 32, 768 + V = output_x # bz, 32, 768 + # scaled dot-product attention + QK = torch.matmul(Q, K.transpose(-1, -2)) / (K.size(-1) ** 0.5) # bz, 1, 32 + QK = F.softmax(QK, dim=-1) # bz, 1, 32 + gate_score = torch.matmul(QK, V) # bz, 1, 768 + gate_score = gate_score.squeeze(1) # bz, 768 + gate_score = self.gate(gate_score) # bz, 1 + return gate_score + + + def forward_cls_route(self, x, beam_scores, expert_route, 
diff --git a/minigpt4/models/moe/route_moe_layer.py b/minigpt4/models/moe/route_moe_layer.py
index 39ecb18..51475a8 100644
--- a/minigpt4/models/moe/route_moe_layer.py
+++ b/minigpt4/models/moe/route_moe_layer.py
@@ -22,6 +22,9 @@ class RouteMoELayer(nn.Module):
             gate = nn.Linear(hidden_size, 1, bias=False).float()
             self.gate = gate
             # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)])
+        elif self.route_method in ['cls-route', 'cls-query-route', 'cls-cross-route']:
+            self.gate = nn.Linear(hidden_size, 1, bias=False).float()
+
 
     def _importance_auxiliary_loss(self, prob_gate):
         # From VMOE
@@ -33,7 +36,6 @@ class RouteMoELayer(nn.Module):
         # Compute coefficient of variation (i.e. std/mean) squared.
         return (std_importance_per_expert / mean_importance_per_expert)**2
 
-
     def forward_gate(self, x):
         """
         x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768])
@@ -91,7 +93,7 @@ class RouteMoELayer(nn.Module):
         return beam_scores, expert_route, beam_idx
 
     def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size):
-        if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']:
+        if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route', 'cls-route', 'cls-query-route', 'cls-cross-route']:
             # current_scores_log torch.Size([bz, num_experts])
             assert beam_scores==None and expert_route==None
             current_scores = torch.exp(current_scores_log)
@@ -247,7 +249,85 @@ class RouteMoELayer(nn.Module):
 
         return final_output, beam_scores, expert_route, beam_idx, importance_loss
 
-    def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
+    def calculate_cls_gate_score(self, cls_hidden, output_x):
+
+        if self.route_method == 'cls-route':
+            # gate on the CLS hidden state alone; cls_hidden: [bz, 768]
+            gate_score = self.gate(cls_hidden) # bz, 1
+        elif self.route_method == 'cls-query-route':
+            # add cls_hidden to the mean-pooled query-token hidden states
+            mean_output = torch.mean(output_x, dim=1) # bz, 768
+            gate_score = self.gate(mean_output + cls_hidden) # bz, 1
+        elif self.route_method == 'cls-cross-route':
+            # use cls_hidden as Q and output_x as K and V, then gate on the attended output
+            # cls_hidden: bz, 768
+            # output_x: bz, 32, 768
+            Q = cls_hidden.unsqueeze(1) # bz, 1, 768
+            K = output_x # bz, 32, 768
+            V = output_x # bz, 32, 768
+            # scaled dot-product attention
+            QK = torch.matmul(Q, K.transpose(-1, -2)) / (K.size(-1) ** 0.5) # bz, 1, 32
+            QK = F.softmax(QK, dim=-1) # bz, 1, 32
+            gate_score = torch.matmul(QK, V) # bz, 1, 768
+            gate_score = gate_score.squeeze(1) # bz, 768
+            gate_score = self.gate(gate_score) # bz, 1
+        return gate_score
+
+
+    def forward_cls_route(self, x, beam_scores, expert_route, cls_hidden):
+        attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device)
+        x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768])
+
+        def forward_expert(input_x, expert_idx):
+            output_x = self.experts[expert_idx].forward(input_x)
+            return output_x
+
+        outputs = list()
+        logits_gate_lst = list()
+        for expert_idx in range(self.num_experts):
+            output_x = forward_expert(x_masked, expert_idx) # bz, 32, 768
+
+            gate_score = self.calculate_cls_gate_score(cls_hidden, output_x) # bz, 1
+
+            logits_gate_lst.append(gate_score)
+            outputs.append(output_x.unsqueeze(0))
+
+        candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768])
+        logits_gate = torch.cat(logits_gate_lst, dim=1) # torch.Size([bz*num_beam, num_expert])
+        current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts])
+
+        current_scores_log = torch.log(current_scores) # log-probabilities can be summed directly
+
+        # importance loss
+        importance_loss = self._importance_auxiliary_loss(current_scores)
+
+        batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam
+
+        beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size)
+
+        # beam_scores torch.Size([bz*num_beam])
+        # expert_route torch.Size([bz*num_beam, layer_n])
+        current_select_expert = expert_route[:,-1]
+        # current_select_expert torch.Size([bz*num_beam])
+
+        if self.layer_judge == 'first':
+            replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size)
+            candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [num_experts, bz*num_beams, 32, 768]
+            current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts)
+            current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts]
+
+        candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([bz*num_beams, num_experts, 32, 768])
+        expert_select_matrix = F.one_hot(current_select_expert, self.num_experts)
+        if self.weight_type == 'ffn_prob':
+            tmp_prob = current_scores[beam_idx] * expert_select_matrix
+            output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1)
+        else:
+            output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
+        final_output = torch.sum(output, dim=1)
+
+        return final_output, beam_scores, expert_route, beam_idx, importance_loss
+
+
+    def forward(self, x, attention_mask, beam_scores, expert_route, cls_hidden=None):
         """
         if first_layer: x [bz, 32, 768]
         else: x [bz*num_beams, 32, 768]
@@ -257,9 +337,8 @@ class RouteMoELayer(nn.Module):
             candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
         elif self.route_method in ['post-route', 'post-route-dp']:
             candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True)
+        elif self.route_method in ['cls-route', 'cls-query-route', 'cls-cross-route']:
+            candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_cls_route(x, beam_scores, expert_route, cls_hidden)
         else:
             assert("route method should in pre-route, post-route, post-route-dp")
         return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
-
-
-
diff --git a/minigpt4/models/moe/uniroute_moe_layer.py b/minigpt4/models/moe/uniroute_moe_layer.py
index 5cd2069..4a2ff02 100644
--- a/minigpt4/models/moe/uniroute_moe_layer.py
+++ b/minigpt4/models/moe/uniroute_moe_layer.py
@@ -142,7 +142,32 @@ class UniRouteMoELayer(nn.Module):
         # import pdb;pdb.set_trace()
         return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
 
-    def forward_post_route_uni(self, x, beam_scores, expert_route, use_log=True):
+
+    def calculate_cls_gate_score(self, cls_hidden, output_x):
+
+        if self.route_method == 'uni-cls-route':
+            # gate on the CLS hidden state alone; cls_hidden: [bz, 768]
+            gate_score = self.gate(cls_hidden) # bz, 1
+        elif self.route_method == 'uni-cls-query-route':
+            # add cls_hidden to the mean-pooled query-token hidden states
+            mean_output = torch.mean(output_x, dim=1) # bz, 768
+            gate_score = self.gate(mean_output + cls_hidden) # bz, 1
+        elif self.route_method == 'uni-cls-cross-route':
+            # use cls_hidden as Q and output_x as K and V, then gate on the attended output
+            # cls_hidden: bz, 768
+            # output_x: bz, 32, 768
+            Q = cls_hidden.unsqueeze(1) # bz, 1, 768
+            K = output_x # bz, 32, 768
+            V = output_x # bz, 32, 768
+            # scaled dot-product attention
+            QK = torch.matmul(Q, K.transpose(-1, -2)) / (K.size(-1) ** 0.5) # bz, 1, 32
+            QK = F.softmax(QK, dim=-1) # bz, 1, 32
+            gate_score = torch.matmul(QK, V) # bz, 1, 768
+            gate_score = gate_score.squeeze(1) # bz, 768
+            gate_score = self.gate(gate_score) # bz, 1
+        return gate_score
+
+
+    def forward_route_uni(self, x, beam_scores, expert_route, use_log=True, cls_hidden=None):
 
         if beam_scores == None:
             batch_size = x.shape[0]
@@ -165,8 +190,14 @@ class UniRouteMoELayer(nn.Module):
         logits_gate_lst = list()
         for expert_idx in range(self.num_route_experts): # num_expert-1
             output_x = forward_expert(x_masked, expert_idx)
-            output_x_aver = torch.mean(output_x, dim=1)
-            gate_score = self.gate(output_x_aver)
+
+            if self.route_method == 'post-route-uni':
+                output_x_aver = torch.mean(output_x, dim=1)
+                gate_score = self.gate(output_x_aver)
+
+            elif self.route_method in ['uni-cls-route', 'uni-cls-query-route', 'uni-cls-cross-route'] and cls_hidden is not None:
+                gate_score = self.calculate_cls_gate_score(cls_hidden, output_x)
+
             logits_gate_lst.append(gate_score)
             outputs.append(output_x.unsqueeze(0))
 
@@ -198,8 +229,6 @@ class UniRouteMoELayer(nn.Module):
             output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1)
         experts_output = torch.sum(output, dim=1) # [bz*num_beams-1, 32, 768]
 
-        # import pdb; pdb.set_trace()
-
         ####################
         ### universal expert
         ####################
@@ -219,15 +248,15 @@ class UniRouteMoELayer(nn.Module):
 
         return final_output, beam_scores, expert_route, beam_idx, importance_loss
 
-    def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True):
+    def forward(self, x, attention_mask, beam_scores, expert_route, cls_hidden):
         """
         if first_layer: x [bz, 32, 768]
        else: x [bz*num_beams, 32, 768]
 
         """
         if self.route_method == 'pre-route-uni':
             candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True)
-        elif self.route_method in ['post-route-uni']:
-            candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route_uni(x, beam_scores, expert_route, use_log=True)
+        elif self.route_method in ['post-route-uni', 'uni-cls-route', 'uni-cls-query-route', 'uni-cls-cross-route']:
+            candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_route_uni(x, beam_scores, expert_route, use_log=True, cls_hidden=cls_hidden)
         return candidate_output, beam_scores, expert_route, beam_idx, importance_loss
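Every routing path in this patch also returns the importance auxiliary loss from _importance_auxiliary_loss. The sketch below is not repo code and the exact reduction in the repository may differ; it mirrors the fragment visible in the route_moe_layer.py hunk above: the squared coefficient of variation of per-expert importance, which approaches zero when the gate spreads probability mass evenly across experts.

import torch

def importance_aux_loss(prob_gate):
    # prob_gate: [num_samples, num_experts] softmax gate probabilities
    importance = prob_gate.sum(dim=0)      # total probability mass routed to each expert
    std = importance.std()
    mean = importance.mean()
    return (std / mean) ** 2               # squared coefficient of variation

# e.g. a perfectly balanced gate gives zero loss:
# importance_aux_loss(torch.full((8, 4), 0.25))  -> tensor(0.)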