sequence.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tensorforce 作者: reinforceio 项目源码 文件源码
def tf_process(self, tensor):
        """Return the current state concatenated with the previous ``self.length - 1`` states.

        A ring buffer of the last ``self.length`` states is kept in a non-trainable
        variable (``states-buffer``) together with a write cursor (``index``).

        - On the first call (``index == -1``) every slot is filled with a copy of the
          incoming state.
        - On later calls only the slot at ``index`` is overwritten.

        Args:
            tensor: Batch of states. The batch dimension (axis 0) must be exactly 1;
                this is checked at run time.

        Returns:
            The buffered states, newest first, concatenated along the last axis, with
            a leading batch dimension of size 1 added back. Evaluating the result
            also advances the write cursor.
        """
        # In graph mode an assert op only fires if something depends on it, so keep a
        # handle and gate the buffer update on it (the bare call was never executed).
        batch_check = tf.assert_equal(x=tf.shape(input=tensor)[0], y=1)

        states_buffer = tf.get_variable(
            name='states-buffer',
            shape=((self.length,) + util.shape(tensor)[1:]),
            dtype=tensor.dtype,
            trainable=False
        )
        # Next slot to write; -1 marks "buffer not filled yet".
        index = tf.get_variable(
            name='index',
            dtype=util.tf_dtype('int'),
            initializer=-1,
            trainable=False
        )

        def fill_buffer():
            # First call: replicate the single state into every slot of the buffer.
            multiples = (self.length,) + tuple(1 for _ in range(util.rank(tensor) - 1))
            return tf.assign(ref=states_buffer, value=tf.tile(input=tensor, multiples=multiples))

        def write_slot():
            # Later calls: overwrite only the slot the cursor points at.
            return tf.assign(ref=states_buffer[index], value=tensor[0])

        with tf.control_dependencies(control_inputs=(batch_check,)):
            is_first_call = tf.equal(x=index, y=-1)
        write = tf.cond(pred=is_first_call, true_fn=fill_buffer, false_fn=write_slot)

        with tf.control_dependencies(control_inputs=(write,)):
            # Fresh reads created after the write. Plain `var[...]` / arithmetic on a
            # variable uses its cached snapshot, which is not ordered after `write`,
            # and the cursor read would race with the cursor update below.
            current_index = index.read_value()
            buffer_value = states_buffer.read_value()

        # Slot holding the state written by this call (0 right after the initial fill).
        newest = tf.maximum(x=current_index, y=0)
        # Newest first, walking backwards around the ring.
        previous_states = [buffer_value[(newest - n) % self.length] for n in range(self.length)]
        # Derived from `newest`, so this update is data-ordered after the cursor read.
        index_update = tf.assign(ref=index, value=((newest + 1) % self.length))

        with tf.control_dependencies(control_inputs=(index_update,)):
            return tf.expand_dims(input=tf.concat(values=previous_states, axis=-1), axis=0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号