a2_base_model.py source code (Python)

Project: text_classification    Author: brightmart
def sub_layer_multi_head_attention(self, layer_index, Q, K_s, type, mask=None, is_training=None, dropout_keep_prob=None):  # COMMON FUNCTION
        """
        multi head attention as a sub layer
        :param layer_index: index of the layer
        :param Q: shape should be: [batch_size,sequence_length,embed_size]
        :param K_s: shape should be: [batch_size,sequence_length,embed_size]
        :param type: encoder, decoder or encoder_decoder_attention
        :param mask: when a mask is used, illegal connections are masked with a large negative value, so their attention probability becomes zero.
        :return: output of multi head attention. shape: [batch_size,sequence_length,d_model]
        """
        with tf.variable_scope("base_mode_sub_layer_multi_head_attention_" + type + str(layer_index)):
            # below is to handle attention for encoder and decoder with different lengths:
            # length=self.decoder_sent_length if (type!='encoder' and self.sequence_length!=self.decoder_sent_length) else self.sequence_length #TODO this may be useful
            length = self.sequence_length
            # 1. get V as learned parameters
            V_s = tf.get_variable("V_s", shape=(self.batch_size, length, self.d_model), initializer=self.initializer)
            # 2. call the multi head attention function to get the result
            multi_head_attention_class = MultiHeadAttention(Q, K_s, V_s, self.d_model, self.d_k, self.d_v, self.sequence_length,
                                                            self.h, type=type, is_training=is_training, mask=mask, dropout_rate=(1.0 - dropout_keep_prob))
            sub_layer_multi_head_attention_output = multi_head_attention_class.multi_head_attention_fn()  # [batch_size*sequence_length,d_model]
        return sub_layer_multi_head_attention_output  # [batch_size,sequence_length,d_model]
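
For reference, a minimal sketch of how this sub-layer might be called for encoder self-attention (where Q and K_s are the same tensor) is given below. The attribute names (batch_size, sequence_length, d_model, d_k, d_v, h, initializer) mirror the fields referenced in the method above, but the wrapper class, the hyper-parameter values and the placeholder input are illustrative assumptions and are not part of a2_base_model.py.

import tensorflow as tf

# Hypothetical container mirroring the attributes the method reads;
# the default values below are placeholders, not the repo's settings.
class BaseClassSketch:
    def __init__(self, batch_size=8, sequence_length=20, d_model=512, d_k=64, d_v=64, h=8):
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h
        self.initializer = tf.random_normal_initializer(stddev=0.1)

    # sub_layer_multi_head_attention would be defined here, exactly as shown above.

# Hypothetical call for encoder self-attention:
# model = BaseClassSketch()
# embeddings = tf.placeholder(tf.float32, [model.batch_size, model.sequence_length, model.d_model])
# attention_output = model.sub_layer_multi_head_attention(
#     layer_index=0, Q=embeddings, K_s=embeddings, type="encoder",
#     mask=None, is_training=True, dropout_keep_prob=0.9)
# attention_output has shape [batch_size, sequence_length, d_model].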