dataset.py source file

python

Project: rl-attack-detection  Author: yenchenlin
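    # Excerpt: a constructor from dataset.py. It relies on module-level imports
    # (numpy as np, tensorflow as tf with the TF 1.x queue API) and on a
    # module-level helper _read_and_decode that is not shown here.
    # Note: colorspace is accepted but not used in this excerpt.
    def __init__(self, directory, num_act, mean_path, num_threads=1, capacity=1e5, batch_size=32,
                 scale=(1.0/255.0), s_t_shape=[84, 84, 4], x_t_1_shape=[84, 84, 1], colorspace='gray'):
        self.scale = scale
        self.s_t_shape = s_t_shape
        self.x_t_1_shape = x_t_1_shape

        # Load the precomputed image mean (a NumPy array saved with np.save)
        mean = np.load(mean_path)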

        # Prepare data flow
        s_t, a_t, x_t_1 = _read_and_decode(directory,
                                        s_t_shape=s_t_shape,
                                        num_act=num_act,
                                        x_t_1_shape=x_t_1_shape)
        self.mean = mean
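        # Queue sizes must be integers; the default capacity (1e5) is a float, so cast it.
        # min_after_dequeue keeps the queue at least 25% full so batches stay well mixed.
        self.s_t_batch, self.a_t_batch, self.x_t_1_batch = tf.train.shuffle_batch([s_t, a_t, x_t_1],
                                                            batch_size=batch_size, capacity=int(capacity),
                                                            min_after_dequeue=int(capacity*0.25),
                                                            num_threads=num_threads)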

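        # Subtract the image mean, then rescale (following J. Oh's design)
        self.mean_const = tf.constant(mean, dtype=tf.float32)
        print(self.mean_const.get_shape())
        # s_t stacks several frames along the channel axis, so repeat the mean once
        # per frame (4 with the default shapes) instead of hard-coding the count.
        num_frames = s_t_shape[2] // x_t_1_shape[2]
        self.s_t_batch = (self.s_t_batch - tf.tile(self.mean_const, [1, 1, num_frames])) * scale
        self.x_t_1_batch = (self.x_t_1_batch - self.mean_const) * scale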
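
The mean subtraction at the end of the constructor can be checked without TensorFlow. Below is a minimal NumPy sketch of the same arithmetic, assuming the saved mean has the shape of a single frame (84, 84, 1) and that batches carry a leading batch dimension. The function name normalize_frames and the random input data are hypothetical and are not part of the original repository.

```python
import numpy as np


def normalize_frames(s_t, x_t_1, mean, scale=1.0 / 255.0):
    """Subtract a mean image and rescale, mirroring the constructor's last two lines.

    s_t:   (batch, H, W, C * k) stack of k past frames
    x_t_1: (batch, H, W, C)     next frame
    mean:  (H, W, C)            mean image (shape is an assumption)
    """
    # Repeat the mean along the channel axis, once per stacked frame.
    repeats = s_t.shape[-1] // mean.shape[-1]
    tiled_mean = np.tile(mean, (1, 1, repeats))
    # Broadcasting applies the (H, W, C) mean to every item in the batch.
    return (s_t - tiled_mean) * scale, (x_t_1 - mean) * scale


# Hypothetical data standing in for decoded frames and the saved mean.
rng = np.random.default_rng(0)
mean = rng.uniform(0, 255, size=(84, 84, 1)).astype(np.float32)
s_t = rng.uniform(0, 255, size=(32, 84, 84, 4)).astype(np.float32)
x_t_1 = rng.uniform(0, 255, size=(32, 84, 84, 1)).astype(np.float32)

s_norm, x_norm = normalize_frames(s_t, x_t_1, mean)
print(s_norm.shape, x_norm.shape)
```

With pixel values in [0, 255] and scale = 1/255, the normalized values fall roughly within [-1, 1], and every stacked frame in s_t is centered with the same mean image.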