thin_stack.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:thinstack-rl 作者: hans 项目源码 文件源码
def _create_state(self):
        """Allocate the non-trainable variables that the recurrence mutates.

        Creates the following float32 ``tf.Variable`` attributes on ``self``:

        * ``stack``: zeros of shape ``(stack_size * batch_size, model_dim)``.
        * ``queue``: zeros of shape ``(stack_size * batch_size,)``.
        * ``buff_cursors``: zeros of shape ``(batch_size,)``.
        * ``cursors``: shape ``(batch_size,)``, every entry set to -1.
        * ``tracking_value``: zeros of shape ``(batch_size, tracking_dim)``.

        The variables are also collected in ``self._aux_vars``, and
        ``self.variable_initializer`` is set to an op that (re-)initializes
        exactly those variables.
        """

        def make_aux(initial_value, name):
            # Every piece of recurrence state is excluded from training.
            return tf.Variable(initial_value, trainable=False, name=name)

        batch = self.batch_size
        # The stack and the queue are stored flattened along their first
        # axis: `stack_size` many blocks of `batch_size` entries each.
        flat_len = self.stack_size * batch

        self.stack = make_aux(
            tf.zeros((flat_len, self.model_dim), dtype=tf.float32), "stack")
        self.queue = make_aux(
            tf.zeros((flat_len,), dtype=tf.float32), "queue")

        # One cursor value per batch element.
        per_example = (batch,)
        self.buff_cursors = make_aux(
            tf.zeros(per_example, dtype=tf.float32), "buff_cursors")
        self.cursors = make_aux(
            tf.ones(per_example, dtype=tf.float32) * -1, "cursors")

        # TODO make parameterizable
        self.tracking_value = make_aux(
            tf.zeros((batch, self.tracking_dim), dtype=tf.float32),
            "tracking_value")

        # Build a single op that resets all of the auxiliary state created
        # above back to its initial value.
        self._aux_vars = [
            self.stack,
            self.queue,
            self.buff_cursors,
            self.cursors,
            self.tracking_value,
        ]
        self.variable_initializer = tf.initialize_variables(self._aux_vars)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号