replay_memory.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:RFR-solution 作者: baoblackcoal 项目源码 文件源码
def __init__(self, config, model_dir, ob_shape_list):
    """Pre-allocate the replay-memory storage arrays and minibatch buffers.

    Args:
        config: object exposing ``cnn_format``, ``memory_size``,
            ``history_length`` and ``batch_size`` attributes.
        model_dir: stored unchanged as ``self.model_dir``.
        ob_shape_list: shape of a single observation, e.g. ``[height, width]``.
            Any sequence of ints is accepted (list or tuple).
    """
    self.model_dir = model_dir

    self.cnn_format = config.cnn_format
    self.memory_size = config.memory_size
    self.history_length = config.history_length
    self.batch_size = config.batch_size
    # Normalise the observation shape to a tuple so the array shapes below can
    # be built by tuple concatenation whether a list or a tuple was passed.
    self.dims = tuple(ob_shape_list)

    # Storage arrays, each holding memory_size entries (contents uninitialised).
    self.actions = np.empty(self.memory_size, dtype=np.uint8)
    # np.int_ is exactly what the deprecated ``dtype=np.integer`` resolved to.
    # NOTE(review): an integer dtype truncates fractional rewards on assignment;
    # confirm rewards are always integral (e.g. clipped) before relying on this.
    self.rewards = np.empty(self.memory_size, dtype=np.int_)
    self.screens = np.empty((self.memory_size,) + self.dims, dtype=np.float16)
    # np.bool_ rather than the np.bool alias, which NumPy 1.24 removed.
    self.terminals = np.empty(self.memory_size, dtype=np.bool_)

    self.count = 0
    self.current = 0

    # Pre-allocate prestates and poststates for a minibatch; presumably the
    # sampling code (not visible here) fills these in place — verify.
    state_shape = (self.batch_size, self.history_length) + self.dims
    self.prestates = np.empty(state_shape, dtype=np.float16)
    self.poststates = np.empty(state_shape, dtype=np.float16)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号