attention_encoder.py source code

Project: master-thesis  Author: AndreasMadsen
import tensorflow as tf
from tensorflow.contrib import rnn


def attention_encoder(x, length,
                      num_blocks=3,
                      name=None, reuse=None):
    """Encode `x` with a stacked bidirectional RNN.

    Args:
        x: float tensor of shape [batch, time, dims].
        length: int tensor of shape [batch] giving the true sequence lengths.
        num_blocks: number of stacked BasicRNNCell layers per direction.

    Returns:
        Tensor of shape [batch, time, 2 * dims] with the forward and
        backward outputs concatenated along the last axis.
    """
    with tf.variable_scope(name, "attention-encoder", values=[x, length],
                           reuse=reuse) as scope:
        # get shapes
        batch_size = x.get_shape().as_list()[0]
        if batch_size is None:
            batch_size = tf.shape(x)[0]

        dims = int(x.get_shape()[-1])

        # encode data
        fw_cell = rnn.MultiRNNCell([
            rnn.BasicRNNCell(dims, reuse=scope.reuse) for _ in range(num_blocks)
        ], state_is_tuple=True)
        bw_cell = rnn.MultiRNNCell([
            rnn.BasicRNNCell(dims, reuse=scope.reuse) for _ in range(num_blocks)
        ], state_is_tuple=True)

        enc_out, _ = tf.nn.bidirectional_dynamic_rnn(
            fw_cell, bw_cell,
            x,
            sequence_length=length,
            initial_state_fw=fw_cell.zero_state(batch_size, tf.float32),
            initial_state_bw=bw_cell.zero_state(batch_size, tf.float32)
        )
        # concatenate forward and backward outputs: [batch, time, 2 * dims]
        enc_out = tf.concat(enc_out, 2)

        return enc_out
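Because the encoder concatenates the forward and backward RNN outputs on the last axis, the feature dimension of the result is twice the input's. A minimal NumPy sketch of that shape arithmetic (the sizes 4, 7, and 16 are illustrative, not taken from the project):

```python
import numpy as np

# Illustrative sizes (assumptions, not from the project):
# batch of 4, 7 time steps, 16 features per direction.
batch_size, max_time, dims = 4, 7, 16

# Stand-ins for the forward and backward RNN output tensors.
fw_out = np.zeros((batch_size, max_time, dims))
bw_out = np.zeros((batch_size, max_time, dims))

# Mirrors tf.concat(enc_out, 2): join along the feature axis.
enc_out = np.concatenate([fw_out, bw_out], axis=2)

print(enc_out.shape)  # (4, 7, 32)
```

This is why a decoder attending over `enc_out` must expect context vectors of size `2 * dims`, not `dims`.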