train.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:async-rl 作者: Grzego 项目源码 文件源码
def __init__(self, action_space, batch_size=32, screen=(84, 84), swap_freq=200,
             beta=None, stats_window=25):
    """Build and compile the training network and set up bookkeeping state.

    Args:
        action_space: Action space object; only its ``n`` attribute (the
            number of discrete actions) is read here.
        batch_size: Number of rows in ``unroll`` and ``targets``.
        screen: Spatial shape of one observation frame; appended after the
            stacked-frame dimension in ``observation_shape``.
        swap_freq: Stored as ``swap_freq`` and used as the initial value of
            ``swap_counter`` (presumably how often weights are swapped —
            confirm against the caller).
        beta: Coefficient passed to ``policy_loss``. When ``None`` (the
            default) the module-level ``args.beta`` is used, matching the
            previous behaviour; pass it explicitly to avoid the global.
        stats_window: ``maxlen`` of the rolling statistics deques
            (default 25, the previously hard-coded value).
    """
    # Deferred import: keras is only imported when this constructor runs.
    from keras.optimizers import RMSprop

    self.screen = screen
    self.input_depth = 1
    self.past_range = 3
    # past_range frames of input_depth each are stacked on the leading axis.
    self.observation_shape = (self.input_depth * self.past_range,) + self.screen
    self.batch_size = batch_size

    _, _, self.train_net, advantage = build_network(
        self.observation_shape, action_space.n)

    if beta is None:
        # Backward-compatible fallback to the module-level ``args`` namespace.
        beta = args.beta

    # Two losses, in the order of the network's outputs as returned by
    # build_network (assumed value head first, then policy — verify).
    self.train_net.compile(optimizer=RMSprop(epsilon=0.1, rho=0.99),
                           loss=[value_loss(), policy_loss(advantage, beta)])

    # Bounded deques keep only the most recent ``stats_window`` entries,
    # presumably for reporting running averages of training statistics.
    self.pol_loss = deque(maxlen=stats_window)
    self.val_loss = deque(maxlen=stats_window)
    self.values = deque(maxlen=stats_window)
    self.entropy = deque(maxlen=stats_window)

    self.swap_freq = swap_freq
    self.swap_counter = self.swap_freq
    # Row indices 0..batch_size-1 (likely for fancy-indexing ``targets`` —
    # verify against the training step).
    self.unroll = np.arange(self.batch_size)
    self.targets = np.zeros((self.batch_size, action_space.n))
    self.counter = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号