attention.py source code

python

Project: sequencing  Author: SwordYork
    def __init__(self, query_size, keys, values, values_length,
                 name='attention'):
        self.attention_size = keys.get_shape().as_list()[-1]
        self.keys = keys
        self.values = values
        self.values_length = values_length
        self.query_trans = LinearOp(query_size, self.attention_size, name=name)

        with tf.variable_scope(name):
            self.v_att = tf.get_variable('v_att', shape=[self.attention_size],
                                         dtype=DTYPE)

        self.time_axis = 0 if TIME_MAJOR else 1

        # Build a mask over the time axis (1 for real inputs, 0 for padding)
        # so that scores for padded inputs can be replaced with tf.float32.min
        num_scores = tf.shape(self.keys)[self.time_axis]
        scores_mask = tf.sequence_mask(
            lengths=tf.to_int32(self.values_length),
            maxlen=tf.to_int32(num_scores),
            dtype=DTYPE)

        if TIME_MAJOR:
            scores_mask = tf.transpose(scores_mask)

        self.scores_mask = scores_mask
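
The constructor above only builds and stores the mask; the scoring and masking steps are not part of this snippet. As a rough NumPy sketch of how such a mask is typically applied, the example below assumes a batch-major layout (the `TIME_MAJOR = False` case) and an additive score of the form `v_att · tanh(keys + projected_query)`, which the names `query_trans` and `v_att` suggest but the snippet does not confirm. The helpers `sequence_mask`, `additive_scores`, and `masked_softmax` are hypothetical names introduced here for illustration, not functions from the sequencing project.

```python
import numpy as np


def sequence_mask(lengths, maxlen):
    """NumPy analogue of tf.sequence_mask: 1.0 for real steps, 0.0 for padding."""
    lengths = np.asarray(lengths)
    # Compare each time index against the sequence length: [batch, maxlen].
    return (np.arange(maxlen)[None, :] < lengths[:, None]).astype(np.float32)


def additive_scores(keys, projected_query, v_att):
    """Assumed additive scoring: keys [batch, time, d], query [batch, d], v_att [d]."""
    hidden = np.tanh(keys + projected_query[:, None, :])
    return np.sum(v_att * hidden, axis=-1)  # [batch, time]


def masked_softmax(scores, mask):
    """Softmax over the time axis that gives padded positions zero weight."""
    # Replace padded scores with the most negative float32, as the comment
    # in the original constructor describes.
    min_value = np.finfo(np.float32).min
    masked = scores * mask + (1.0 - mask) * min_value
    # Subtract the row maximum for numerical stability before exponentiating.
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)


# Two sequences padded to 3 steps; the first has only 2 real steps.
mask = sequence_mask([2, 3], maxlen=3)
keys = np.zeros((2, 3, 4), dtype=np.float32)
query = np.zeros((2, 4), dtype=np.float32)
v_att = np.ones(4, dtype=np.float32)

scores = additive_scores(keys, query, v_att)
weights = masked_softmax(scores, mask)
print(weights)  # the padded step in the first row receives weight 0
```

For the time-major case, the mask would be transposed to `[time, batch]` as in the constructor, and the softmax would run over axis 0 instead of the last axis.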