Attention.py source code

python

Project: TFCommon    Author: MU94W
# Constructor of the attention module (partial snippet of the class in Attention.py).
import math

import tensorflow as tf
from tensorflow.python.ops import array_ops


def __init__(self, attention_units, memory, sequence_length=None, time_major=True, mode=0):
    self.attention_units    = attention_units
    self.enc_units          = memory.get_shape()[-1].value

    # Convert time-major memory [T, B, D] to batch-major [B, T, D].
    if time_major:
        memory = tf.transpose(memory, perm=(1, 0, 2))

    self.enc_length = tf.shape(memory)[1]
    self.batch_size = tf.shape(memory)[0]
    self.mode = mode
    # Boolean mask over valid encoder steps; padded positions later receive -inf scores.
    self.mask = array_ops.sequence_mask(sequence_length, self.enc_length) if sequence_length is not None else None
    self.tiny = -math.inf * tf.ones(shape=(self.batch_size, self.enc_length))

    # Reshape memory to [B, T, 1, enc_units] so a 1x1 convolution acts per time step.
    self.memory = tf.reshape(memory, (tf.shape(memory)[0], self.enc_length, 1, self.enc_units))
    ### pre-compute Ua * h_j once to minimize the per-decoder-step computational cost
    with tf.variable_scope('attention'):
        Ua = tf.get_variable(name='Ua', shape=(1, 1, self.enc_units, self.attention_units))
    self.hidden_feats = tf.nn.conv2d(self.memory, Ua, [1, 1, 1, 1], "SAME")
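The snippet above only covers the constructor: it stores the encoder memory and pre-computes Ua * h_j as `hidden_feats`. For context, the sketch below shows how such a pre-computed tensor is typically consumed in additive (Bahdanau) attention, where each score is e_ij = Va^T tanh(Wa * s_i + Ua * h_j). This is a minimal illustrative sketch, not the project's actual scoring code; the function name `additive_attention_step` and the variables `Wa` and `Va` are assumed for illustration.

import tensorflow as tf


def additive_attention_step(query, hidden_feats, memory, mask, tiny, attention_units):
    """One additive (Bahdanau) attention step -- illustrative sketch only.

    query        : [batch, dec_units]              current decoder state s_i
    hidden_feats : [batch, enc_len, 1, att_units]  pre-computed Ua * h_j
    memory       : [batch, enc_len, 1, enc_units]  encoder outputs
    mask         : [batch, enc_len] bool, or None
    tiny         : [batch, enc_len] of -inf, used to mask padded steps
    """
    with tf.variable_scope('attention', reuse=tf.AUTO_REUSE):
        # Hypothetical parameter names; the real module may name/shape them differently.
        Wa = tf.get_variable('Wa', shape=(query.get_shape()[-1].value, attention_units))
        Va = tf.get_variable('Va', shape=(attention_units,))

    # Wa * s_i, reshaped so it broadcasts over the encoder time axis.
    query_feat = tf.reshape(tf.matmul(query, Wa), (-1, 1, 1, attention_units))

    # e_ij = Va^T tanh(Wa * s_i + Ua * h_j)  ->  [batch, enc_len]
    scores = tf.reduce_sum(Va * tf.tanh(hidden_feats + query_feat), axis=(2, 3))

    # Mask padded encoder positions before the softmax.
    if mask is not None:
        scores = tf.where(mask, scores, tiny)

    # Alignment weights and context vector.
    alignments = tf.nn.softmax(scores)                                     # [batch, enc_len]
    weights = tf.expand_dims(tf.expand_dims(alignments, -1), -1)           # [batch, enc_len, 1, 1]
    context = tf.reduce_sum(weights * memory, axis=(1, 2))                 # [batch, enc_units]
    return context, alignments

Because Ua * h_j does not depend on the decoder state, computing it once in the constructor avoids repeating that projection over the whole encoder sequence at every decoder step, which is exactly the cost the original comment refers to.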