Attention.py source code

Project: TFCommon · Author: MU94W
import tensorflow as tf  # TF 1.x API (tf.variable_scope / tf.get_variable)


# Method of the attention module class; `self.hidden_feats`, `self.memory`,
# `self.mask`, `self.enc_length`, `self.attention_units` and `self.mode`
# are set elsewhere in the class.
def __call__(self, query):
    with tf.variable_scope('attention'):
        # The memory's batch size is assumed to match the query's batch size.
        query_units = query.get_shape()[-1].value

        # Attention parameters.
        Wa = tf.get_variable(name='Wa', shape=(query_units, self.attention_units))
        Va = tf.get_variable(name='Va', shape=(self.attention_units,),
                             initializer=tf.constant_initializer(0.0) if self.mode == 0 else tf.constant_initializer(1e-2))
        b  = tf.get_variable(name='b',  shape=(self.attention_units,),
                             initializer=tf.constant_initializer(0.0) if self.mode == 0 else tf.constant_initializer(0.5))

        # 1st. Compute query_feat (the query's representation in the attention module),
        #      shape (batch, 1, 1, attention_units).
        query_feat = tf.reshape(tf.matmul(query, Wa), (-1, 1, 1, self.attention_units))

        # 2nd. Compute the energy for all encoder time steps
        #      (element-wise multiply, then reduce), shape (batch, enc_length).
        e = tf.reduce_sum(Va * tf.nn.tanh(self.hidden_feats + query_feat + b), axis=(2, 3))

        # 3rd. Compute the attention weights (masked softmax over encoder steps).
        if self.mask is not None:
            exp_e = tf.exp(e) * self.mask
            alpha = tf.divide(exp_e, tf.reduce_sum(exp_e, axis=-1, keep_dims=True))
        else:
            alpha = tf.nn.softmax(e)

        # 4th. Take the attention-weighted sum of the memory,
        #      summed over the encoder positions.
        context = tf.reshape(alpha, (tf.shape(query)[0], self.enc_length, 1, 1)) * self.memory
        context = tf.reduce_sum(context, axis=(1, 2))

        return context, alpha
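
For reference, steps 1-3 above implement the standard additive (Bahdanau-style) attention score. Assuming `self.hidden_feats` holds the pre-projected encoder features (computed elsewhere in the class), the computation for encoder step j is:

    e_j     = Va^T · tanh(hidden_feats_j + query·Wa + b)
    alpha_j = mask_j · exp(e_j) / Σ_k mask_k · exp(e_k)    (mask_j = 1 everywhere when self.mask is None)
    context = Σ_j alpha_j · memory_j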
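
To see what the masked normalisation in step 3 does, here is a minimal standalone NumPy sketch with toy values (not part of TFCommon): padded encoder positions with mask == 0 get zero weight, and the remaining weights are renormalised to sum to 1.

    import numpy as np

    e    = np.array([[1.0, 2.0, 3.0, 4.0]])   # raw energies, shape (batch, enc_length)
    mask = np.array([[1.0, 1.0, 1.0, 0.0]])   # last encoder step is padding

    exp_e = np.exp(e) * mask                   # zero out the padded position
    alpha = exp_e / exp_e.sum(axis=-1, keepdims=True)
    print(alpha)                               # approx [[0.090 0.245 0.665 0.   ]]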