rnn_cell.py source file

python

Project: bi-att-flow  Author: allenai
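import tensorflow as tf

# NOTE: double_linear_logits is not defined in this snippet; it is assumed to be
# in scope from the project's own helper code.


def get_double_linear_controller(size, bias, input_keep_prob=1.0, is_train=None):
    def double_linear_controller(inputs, state, memory):
        """
        Compute one logit per memory slot from the inputs, state, and memory.

        :param inputs: [N, i]
        :param state: [N, d], or a tuple of such tensors
        :param memory: [N, M, m]
        :return: [N, M]
        """
        rank = len(memory.get_shape())
        # M, the number of memory slots (second-to-last dimension of memory)
        _memory_size = tf.shape(memory)[rank-2]
        # Repeat inputs and state(s) along a new axis 1 so each memory slot sees them
        tiled_inputs = tf.tile(tf.expand_dims(inputs, 1), [1, _memory_size, 1])
        if isinstance(state, tuple):
            tiled_states = [tf.tile(tf.expand_dims(each, 1), [1, _memory_size, 1])
                            for each in state]
        else:
            tiled_states = [tf.tile(tf.expand_dims(state, 1), [1, _memory_size, 1])]

        # [N, M, i + (total state width) + m]
        # Axis-first argument order, as in TensorFlow releases before 1.0;
        # TensorFlow 1.0 and later expect tf.concat(values, axis).
        in_ = tf.concat(2, [tiled_inputs] + tiled_states + [memory])
        out = double_linear_logits(in_, size, bias, input_keep_prob=input_keep_prob,
                                   is_train=is_train)
        return out
    return double_linear_controller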
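
The tiling and concatenation step in the controller can be checked in isolation. Below is a minimal NumPy sketch of that step only, not the project's code: `tile_and_concat` and the sizes `N, M, i, d, m` are hypothetical names chosen for illustration, and the final `double_linear_logits` call is left out because its definition is not shown here.

```python
import numpy as np


def tile_and_concat(inputs, state, memory):
    """NumPy analogue of the tiling step (hypothetical helper, for illustration).

    inputs: [N, i]; state: [N, d] or a tuple of such arrays; memory: [N, M, m].
    Returns an array of shape [N, M, i + (total state width) + m].
    """
    memory_size = memory.shape[-2]  # M, the number of memory slots
    states = list(state) if isinstance(state, tuple) else [state]
    # Add an axis at position 1 and repeat M times so every slot sees the same vector
    tiled = [np.tile(np.expand_dims(x, 1), (1, memory_size, 1))
             for x in [inputs] + states]
    # Join along the feature axis, in the same order as the controller above
    return np.concatenate(tiled + [memory], axis=2)


# Assumed example sizes: batch 2, 5 memory slots, feature widths 3, 4 and 6
N, M, i, d, m = 2, 5, 3, 4, 6
inputs = np.zeros((N, i))
state = np.ones((N, d))
memory = np.full((N, M, m), 2.0)

combined = tile_and_concat(inputs, state, memory)
print(combined.shape)  # (2, 5, 13)
```

With a tuple state such as `(state, state)`, the feature axis grows by one state width per element, giving 17 here instead of 13.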