karpathy_cnn.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:DeepRL 作者: arnomoonens 项目源码 文件源码
def __init__(self, env, monitor_path, video=True, **usercfg):
        """Set up the Karpathy-style CNN agent around a monitored environment.

        Args:
            env: Environment to train on. Its ``action_space.n`` is read, so a
                discrete action space is assumed.
            monitor_path: Directory handed to ``wrappers.Monitor`` (with
                ``force=True``). It is also stored on ``self.monitor_path``.
            video: If falsy, ``video_callable=False`` is passed to the Monitor,
                presumably disabling video recording. Otherwise ``None`` is
                passed, so the Monitor's own default applies.
            **usercfg: Configuration overrides. They are forwarded to the base
                class and then applied again on top of the defaults below, so
                any key given here takes precedence over a default.

        Side effects:
            Calls ``self.build_network()``. If ``self.config["save_model"]`` is
            truthy, it also registers ``self.action`` and ``self.states`` in the
            "action" and "states" TensorFlow collections and creates
            ``self.saver``.
        """
        super(KarpathyCNN, self).__init__(**usercfg)

        video_schedule = None if video else False
        self.env = wrappers.Monitor(
            env, monitor_path, force=True, video_callable=video_schedule
        )
        # Number of discrete actions, read from the unwrapped environment.
        self.nA = env.action_space.n
        self.monitor_path = monitor_path

        # Agent-specific defaults; user-supplied keys are re-applied afterwards.
        defaults = {
            "n_hidden_units": 200,
            "learning_rate": 1e-3,
            "batch_size": 10,  # number of episodes between gradient updates
            "gamma": 0.99,  # reward discount factor
            "decay": 0.99,  # RMSProp decay
            "epsilon": 1e-9,  # RMSProp epsilon
            "draw_frequency": 50,  # draw a plot once every this many episodes
        }
        self.config.update(defaults)
        self.config.update(usercfg)

        self.build_network()

        if self.config["save_model"]:
            # Publish these tensors under named graph collections, presumably so
            # they can be looked up again after a saved model is restored.
            tf.add_to_collection("action", self.action)
            tf.add_to_collection("states", self.states)
            self.saver = tf.train.Saver()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号