Example source code for the Python make() method
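Every snippet on this page centers on gym.make(), which looks up a registered environment id and returns a ready-to-use environment instance. As a minimal, self-contained sketch of the pattern (assuming the classic pre-0.26 gym API that all of these projects use, where reset() returns an observation and step() returns four values):

import gym

env = gym.make("CartPole-v0")                 # resolve the registered id to an env instance
observation = env.reset()
done = False
while not done:
    action = env.action_space.sample()        # random actions, purely for illustration
    observation, reward, done, info = env.step(action)
env.close()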

train_pong.py (project: rl-attack-detection, author: yenchenlin)
def main():
    env = gym.make("PongNoFrameskip-v4")
    env = ScaledFloatFrame(wrap_dqn(env))
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=2000000,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=True
    )
    act.save("pong_model.pkl")
    env.close()
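Once training finishes, the pickled policy can be reloaded and replayed. A hedged sketch mirroring baselines' companion enjoy script from the same era (deepq.load returns a callable act function, and the env must be wrapped exactly as during training):

def enjoy():
    env = gym.make("PongNoFrameskip-v4")
    env = ScaledFloatFrame(wrap_dqn(env))  # same preprocessing as in training
    act = deepq.load("pong_model.pkl")
    obs, done = env.reset(), False
    episode_rew = 0
    while not done:
        env.render()
        obs, rew, done, _ = env.step(act(obs[None])[0])
        episode_rew += rew
    print("Episode reward", episode_rew)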
funfun.py (project: chainer_pong, author: icoxfog417)
def main(game_count=1):
    record = os.path.join(os.path.dirname(__file__), "funfun")
    env = gym.make("Pong-v0")
    hanamichi = Hanamichi()

    env.monitor.start(record)
    for i in range(game_count):
        playing = True
        observation = env.reset()
        reward = -1
        action = -1

        while playing:
            env.render()
            if action < 0:
                action = hanamichi.start(observation)
            else:
                action = hanamichi.act(observation, reward)
            observation, reward, done, info = env.step(action)
            playing = not done
            if done:
                hanamichi.end(reward)

    env.monitor.close()
base.py (project: pytorch.rl.learning, author: moskomule)
def __init__(self, env_name, num_episodes, alpha, gamma, policy, report_freq=100, **kwargs):
        """
        base class for RL using lookup table
        :param env_name: see https://github.com/openai/gym/wiki/Table-of-environments
        :param num_episodes: int, number of episodes for training
        :param alpha: float, learning rate
        :param gamma: float, discount rate
        :param policy: str
        :param report_freq: int, by default 100
        :param kwargs: other arguments
        """
        self.env = gym.make(env_name)
        self.num_episodes = num_episodes
        self.alpha = alpha
        self.gamma = gamma
        self.state = None
        self._rewards = None
        self._policy = policy
        self.report_freq = report_freq
        for k, v in kwargs.items():
            setattr(self, str(k), v)
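The docstring above spells out the constructor contract, and the final loop attaches any extra keyword arguments to the instance via setattr. A hypothetical usage sketch (the subclass name Sarsa and the epsilon kwarg are illustrative assumptions, not names from the project; TableBase stands in for the base class defined above):

class Sarsa(TableBase):
    pass  # a real subclass would implement its training loop

agent = Sarsa("Taxi-v2", num_episodes=5000, alpha=0.1, gamma=0.99,
              policy="epsilon_greedy", epsilon=0.1)
assert agent.epsilon == 0.1  # epsilon arrived through **kwargs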
wrapper.py (project: pytorch.rl.learning, author: moskomule)
def make_atari(env_id, noop=True, max_and_skip=True, episode_life=True, clip_rewards=True, frame_stack=True,
               scale=True):
    """Configure environment for DeepMind-style Atari.
    """
    env = gym.make(env_id)
    assert 'NoFrameskip' in env.spec.id
    if noop:
        env = NoopResetEnv(env, noop_max=30)
    if max_and_skip:
        env = MaxAndSkipEnv(env, skip=4)
    if episode_life:
        env = EpisodicLifeEnv(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, 4)
    return env
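A hedged usage note: the assert above means the helper must be given a NoFrameskip id, and the keyword flags let an evaluation env opt out of the reward clipping and episodic-life behavior used during training, for example:

train_env = make_atari('BreakoutNoFrameskip-v4')
eval_env = make_atari('BreakoutNoFrameskip-v4',
                      episode_life=False, clip_rewards=False)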
train_pong.py (project: combine-DT-with-NN-in-RL, author: Burning-Bear)
def main():
    env = gym.make("PongNoFrameskip-v4")
    env = ScaledFloatFrame(wrap_dqn(env))
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=2000000,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=True
    )
    act.save("pong_model.pkl")
    env.close()
train_cartpole.py (project: combine-DT-with-NN-in-RL, author: Burning-Bear)
def main():
    env = gym.make("CartPole-v0")
    model = deepq.models.mlp([64])
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-3,
        max_timesteps=100000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        print_freq=10,
        callback=callback
    )
    print("Saving model to cartpole_model.pkl")
    act.save("cartpole_model.pkl")
test_contextual_envs.py (project: gym-extensions, author: Breakend)
def test_cartpole_contextual():
    env_id = 'CartPoleContextual-v0'
    env = gym.make(env_id)
    if isinstance(env.unwrapped, CartPoleEnv):
        env.reset()
    else:
        raise NotImplementedError

    nr_of_items_context_space_info = 10
    nr_unwrapped = len(list(env.unwrapped.context_space_info().keys()))
    if nr_of_items_context_space_info != nr_unwrapped:
        print('context_space_info() function needs to be implemented!')
        raise NotImplementedError

    context_vect = [0.01, 0.01, 0.01, 0.01]
    # these values should differ from the current context; change_context() below applies them
    if context_vect == env.unwrapped.context:
        raise AttributeError

    env.unwrapped.change_context(context_vect)
    if context_vect != env.unwrapped.context:
        raise AttributeError
test_contextual_envs.py (project: gym-extensions, author: Breakend)
def test_pendulum_contextual():
    env_id = 'PendulumContextual-v0'
    env = gym.make(env_id)
    if isinstance(env.unwrapped, PendulumEnv):
        env.reset()
    else:
        raise NotImplementedError

    nr_of_items_context_space_info = 10
    nr_unwrapped = len(list(env.unwrapped.context_space_info().keys()))
    if nr_of_items_context_space_info != nr_unwrapped:
        print('context_space_info() function needs to be implemented!')
        raise NotImplementedError

    context_vect = [0.01, 0.01]
    if context_vect == env.unwrapped.context:
        raise AttributeError

    env.unwrapped.change_context(context_vect)
    if context_vect != env.unwrapped.context:
        raise AttributeError
train_cartpole.py (project: rl-attack-detection, author: yenchenlin)
def main():
    env = gym.make("CartPole-v0")
    model = deepq.models.mlp([64])
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-3,
        max_timesteps=100000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        print_freq=10,
        callback=callback
    )
    print("Saving model to cartpole_model.pkl")
    act.save("cartpole_model.pkl")
train_mountaincar.py (project: baselines, author: openai)
def main():
    env = gym.make("MountainCar-v0")
    # Enabling layer_norm here is important for parameter space noise!
    model = deepq.models.mlp([64], layer_norm=True)
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-3,
        max_timesteps=100000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.1,
        print_freq=10,
        param_noise=True
    )
    print("Saving model to mountaincar_model.pkl")
    act.save("mountaincar_model.pkl")
train_cartpole.py (project: baselines, author: openai)
def main():
    env = gym.make("CartPole-v0")
    model = deepq.models.mlp([64])
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-3,
        max_timesteps=100000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        print_freq=10,
        callback=callback
    )
    print("Saving model to cartpole_model.pkl")
    act.save("cartpole_model.pkl")
run_mujoco.py (project: baselines, author: openai)
def train(env_id, num_timesteps, seed):
    import baselines.common.tf_util as U
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    if rank != 0:
        logger.set_level(logger.DISABLED)
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = gym.make(env_id)
    def policy_fn(name, ob_space, ac_space):
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=32, num_hid_layers=2)
    env = bench.Monitor(env, logger.get_dir() and
        osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
        max_timesteps=num_timesteps, gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
    env.close()
run_mujoco.py (project: baselines, author: openai)
def train(env_id, num_timesteps, seed):
    env = gym.make(env_id)
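    # NOTE: rank is assumed to come from the enclosing module (e.g. an MPI rank, as in the previous snippet)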
    env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    set_global_seeds(seed)
    env.seed(seed)
    gym.logger.setLevel(logging.WARN)

    with tf.Session(config=tf.ConfigProto()):
        ob_dim = env.observation_space.shape[0]
        ac_dim = env.action_space.shape[0]
        with tf.variable_scope("vf"):
            vf = NeuralNetValueFunction(ob_dim, ac_dim)
        with tf.variable_scope("pi"):
            policy = GaussianMlpPolicy(ob_dim, ac_dim)

        learn(env, policy=policy, vf=vf,
            gamma=0.99, lam=0.97, timesteps_per_batch=2500,
            desired_kl=0.002,
            num_timesteps=num_timesteps, animate=False)

        env.close()
train_space_invaders.py (project: ai-bs-summer17, author: uchibe)
def main():
    env = gym.make("SpaceInvadersNoFrameskip-v4")
    env = ScaledFloatFrame(wrap_dqn(env))
    model = deepq.models.cnn_to_mlp(
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True
    )
    act = deepq.learn(
        env,
        q_func=model,
        lr=1e-4,
        max_timesteps=2000000,
        buffer_size=10000,
        exploration_fraction=0.1,
        exploration_final_eps=0.01,
        train_freq=4,
        learning_starts=10000,
        target_network_update_freq=1000,
        gamma=0.99,
        prioritized_replay=True
    )
    act.save("space_invaders_model.pkl")
    env.close()
pposgd_atlas.py (project: ai-bs-summer17, author: uchibe)
def train(env_id, num_timesteps, seed):
    from baselines.pposgd import mlp_policy, pposgd_simple
    U.make_session(num_cpu=1).__enter__()
    logger.session().__enter__()
    set_global_seeds(seed)
    env = gym.make(env_id)
    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=64, num_hid_layers=2)
    env = bench.Monitor(env, osp.join(logger.get_dir(), "monitor.json"))
    env.seed(seed)
    gym.logger.setLevel(logging.WARN)
    pposgd_simple.learn(env, policy_fn, 
            max_timesteps=num_timesteps,
            timesteps_per_batch=2048,
            clip_param=0.2, entcoeff=0.0,
            optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
            gamma=0.99, lam=0.95,
        )
    env.close()
envs.py (project: pytorch-a2c-ppo-acktr, author: ikostrikov)
def make_env(env_id, seed, rank, log_dir):
    def _thunk():
        env = gym.make(env_id)
        is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            env = make_atari(env_id)
        env.seed(seed + rank)
        if log_dir is not None:
            env = bench.Monitor(env, os.path.join(log_dir, str(rank)))
        if is_atari:
            env = wrap_deepmind(env)
        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        obs_shape = env.observation_space.shape
        if len(obs_shape) == 3 and obs_shape[2] in [1, 3]:
            env = WrapPyTorch(env)
        return env

    return _thunk
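make_env returns a thunk instead of an env so that each subprocess of a vectorized environment can build its own copy. A hedged consumption sketch (SubprocVecEnv from baselines' vec_env package is an assumption about how this project wires things up):

from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

num_processes = 8
envs = SubprocVecEnv([make_env('PongNoFrameskip-v4', seed=1, rank=i, log_dir=None)
                      for i in range(num_processes)])
obs = envs.reset()  # one stacked observation per worker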
dqn_cartpole.py (project: ngraph, author: NervanaSystems)
def main():
    # initialize gym environment
    environment = gym.make('CartPole-v0')

    state_axes = ng.make_axes([
        ng.make_axis(environment.observation_space.shape[0], name='width')
    ])

    agent = dqn.Agent(
        state_axes,
        environment.action_space,
        model=baselines_model,
        epsilon=dqn.linear_generator(start=1.0, end=0.02, steps=10000),
        learning_rate=1e-3,
        gamma=1.0,
        memory=dqn.Memory(maxlen=50000),
        learning_starts=1000,
    )

    rl_loop.rl_loop_train(environment, agent, episodes=1000)

    total_reward = rl_loop.evaluate_single_episode(environment, agent)
    print(total_reward)
test_integration_dqn.py (project: ngraph, author: NervanaSystems)
def test_dependent_environment():
    environment = gym.make('DependentEnv-v0')

    total_rewards = []
    for i in range(10):
        agent = dqn.Agent(
            dqn.space_shape(environment.observation_space),
            environment.action_space,
            model=model,
            epsilon=dqn.decay_generator(start=1.0, decay=0.995, minimum=0.1),
            gamma=0.99,
            learning_rate=0.1,
        )

        rl_loop.rl_loop_train(environment, agent, episodes=10)

        total_rewards.append(
            rl_loop.evaluate_single_episode(environment, agent)
        )

    # most of these 10 agents will be able to converge to the perfect policy
    assert np.mean(np.array(total_rewards) == 100) >= 0.5
envs.py (project: DHP, author: YuhangSong)
def create_flash_env(env_id, client_id, remotes, **_):
    env = gym.make(env_id)
    env = Vision(env)
    env = Logger(env)
    env = BlockingReset(env)

    reg = universe.runtime_spec('flashgames').server_registry
    height = reg[env_id]["height"]
    width = reg[env_id]["width"]
    env = CropScreen(env, height, width, 84, 18)
    env = FlashRescale(env)

    keys = ['left', 'right', 'up', 'down', 'x']
    env = DiscreteToFixedKeysVNCActions(env, keys)
    env = EpisodeID(env)
    env = DiagnosticsInfo(env)
    env = Unvectorize(env)
    env.configure(fps=5.0, remotes=remotes, start_timeout=15 * 60, client_id=client_id,
                  vnc_driver='go', vnc_kwargs={
                    'encoding': 'tight', 'compress_level': 0,
                    'fine_quality_level': 50, 'subsample_level': 3})
    return env
gym_wrapper.py (project: PAAC.pytorch, author: qbx2)
def make(env_id, hack=None):
    if 'Deterministic-v4' not in env_id:
        print('[Warning] Use the Deterministic-v4 version '
              'to reproduce the results of the paper.')

    _env = env = gym.make(env_id)

    if hack:
        # Hack gym env to output grayscale image
        if env.spec.timestep_limit is not None:
            from gym.wrappers.time_limit import TimeLimit

            if isinstance(env, TimeLimit):
                _env = env.env

        if hack == 'train':
            _env._get_image = _env.ale.getScreenGrayscale
            _env._get_obs = _env.ale.getScreenGrayscale
        elif hack == 'eval':
            _env._get_obs = _env.ale.getScreenGrayscale

    return env
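Per the warning above, a Deterministic-v4 id reproduces the paper's setup; a hedged one-line usage sketch:

env = make('BreakoutDeterministic-v4', hack='train')  # grayscale observations via the ALE hack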
dqn.py (project: chi, author: rmst)
def dqn_test(env='OneRoundDeterministicReward-v0'):
    env = gym.make(env)
    env = ObservationShapeWrapper(env)

    @tt.model(tracker=tf.train.ExponentialMovingAverage(1 - .01),
              optimizer=tf.train.AdamOptimizer(.01))
    def q_network(x):
        x = layers.fully_connected(x, 32)
        x = layers.fully_connected(x, env.action_space.n, activation_fn=None,
                                   weights_initializer=tf.random_normal_initializer(0, 1e-4))
        return x

    agent = DqnAgent(env, q_network, double_dqn=False, replay_start=100, annealing_time=100)

    rs = []
    for ep in range(10000):
        r, _ = agent.play_episode()

        rs.append(r)

        if ep % 100 == 0:
            print(f'Return after episode {ep} is {sum(rs)/len(rs)}')
            rs = []
test_time_limit.py (project: universe, author: openai)
def test_steps_limit_restart():
    env = gym.make('test.StepsLimitDummyVNCEnv-v0')
    env.configure(_n=1)
    env = wrappers.TimeLimit(env)
    env.reset()

    assert env._max_episode_seconds == None
    assert env._max_episode_steps == 2

    # Episode has started
    _, _, done, info = env.step([[]])
    assert done == [False]

    # Limit reached, now we get a done signal and the env resets itself
    _, _, done, info = env.step([[]])
    assert done == [True]
    assert env._elapsed_steps == 0
test_time_limit.py (project: universe, author: openai)
def test_seconds_limit_restart():
    env = gym.make('test.SecondsLimitDummyVNCEnv-v0')
    env.configure(_n=1)
    env = wrappers.TimeLimit(env)
    env.reset()

    assert env._max_episode_seconds == 0.1
    assert env._max_episode_steps == None

    # Episode has started
    _, _, done, info = env.step([[]])
    assert done == [False]

    # Not enough time has passed
    _, _, done, info = env.step([[]])
    assert done == [False]

    time.sleep(0.2)

    # Limit reached, now we get a done signal and the env resets itself
    _, _, done, info = env.step([[]])
    assert done == [True]
test_time_limit.py (project: universe, author: openai)
def test_default_time_limit():
    # We need an env without a default limit
    register(
        id='test.NoLimitDummyVNCEnv-v0',
        entry_point='universe.envs:DummyVNCEnv',
        tags={
            'vnc': True,
            },
    )

    env = gym.make('test.NoLimitDummyVNCEnv-v0')
    env.configure(_n=1)
    env = wrappers.TimeLimit(env)
    env.reset()

    assert env._max_episode_seconds == wrappers.time_limit.DEFAULT_MAX_EPISODE_SECONDS
    assert env._max_episode_steps == None
test_joint.py (project: universe, author: openai)
def test_joint():
    env1 = gym.make('test.DummyVNCEnv-v0')
    env2 = gym.make('test.DummyVNCEnv-v0')
    env1.configure(_n=3)
    env2.configure(_n=3)
    for reward_buffer in [env1._reward_buffers[0], env2._reward_buffers[0]]:
        reward_buffer.set_env_info('running', 'test.DummyVNCEnv-v0', '1', 60)
        reward_buffer.reset('1')
        reward_buffer.push('1', 10, False, {})

    env = wrappers.Joint([env1, env2])
    assert env.n == 6
    observation_n = env.reset()
    assert observation_n == [None] * 6

    observation_n, reward_n, done_n, info = env.step([[] for _ in range(env.n)])
    assert reward_n == [10.0, 0.0, 0.0, 10.0, 0.0, 0.0]
    assert done_n == [False] * 6
gym_core.py (project: universe, author: openai)
def __init__(self, env, gym_core_id=None):
        super(GymCoreAction, self).__init__(env)

        if gym_core_id is None:
            # self.spec is None while inside gym.make(), so gym_core_id
            # must be passed in explicitly there. This branch is hit when
            # the wrapper is instantiated by hand.
            gym_core_id = self.spec._kwargs['gym_core_id']

        spec = gym.spec(gym_core_id)
        raw_action_space = gym_core_action_space(gym_core_id)

        self._actions = raw_action_space.actions
        self.action_space = gym_spaces.Discrete(len(self._actions))

        if spec._entry_point.startswith('gym.envs.atari:'):
            self.key_state = translator.AtariKeyState(gym.make(gym_core_id))
        else:
            self.key_state = None
test_semantics.py (project: universe, author: openai)
def test_describe_handling():
    env = gym.make('flashgames.DuskDrive-v0')
    env.configure(vnc_driver=FakeVNCSession, rewarder_driver=FakeRewarder, remotes='vnc://example.com:5900+15900')
    env.reset()

    reward_buffer = get_reward_buffer(env)
    rewarder_client = get_rewarder_client(env)

    rewarder_client._manual_recv('v0.env.describe', {'env_id': 'flashgames.DuskDrive-v0', 'env_state': 'resetting', 'fps': 60}, {'episode_id': '1'})

    assert reward_buffer._remote_episode_id == '1'
    assert reward_buffer._remote_env_state == 'resetting'
    assert reward_buffer._current_episode_id == None
    assert reward_buffer.reward_state(reward_buffer._current_episode_id)._env_state == None

    rewarder_client._manual_recv('v0.reply.env.reset', {}, {'episode_id': '1'})

    assert reward_buffer._remote_episode_id == '1'
    assert reward_buffer._remote_env_state == 'resetting'
    assert reward_buffer._current_episode_id == '1'
    assert reward_buffer.reward_state(reward_buffer._current_episode_id)._env_state == 'resetting'
test_envs.py (project: universe, author: openai)
def test_smoke(env_id):
    """Check that environments start up without errors and that we can extract rewards and observations"""
    gym.undo_logger_setup()
    logging.getLogger().setLevel(logging.INFO)

    env = gym.make(env_id)
    if env.metadata.get('configure.required', False):
        if os.environ.get('FORCE_LATEST_UNIVERSE_DOCKER_RUNTIMES'):  # Used to test universe-envs in CI
            configure_with_latest_docker_runtime_tag(env)
        else:
            env.configure(remotes=1)

    env = wrappers.Unvectorize(env)

    env.reset()
    _rollout(env, timestep_limit=60*30) # Check a rollout
train_agent_chainer.py (project: gym-malware, author: endgameinc)
def train_agent(rounds=10000, use_score=False, name='result_dir', create_agent=create_ddqn_agent):
    ENV_NAME = 'malware-score-v0' if use_score else 'malware-v0'
    env = gym.make(ENV_NAME)
    np.random.seed(123)
    env.seed(123)

    agent = create_agent(env)

    chainerrl.experiments.train_agent_with_evaluation(
        agent, env,
        steps=rounds,                   # Train the agent for this many steps
        max_episode_len=env.maxturns,   # Maximum length of each episode
        eval_interval=1000,             # Evaluate the agent after every 1000 steps
        eval_n_runs=100,                # 100 episodes are sampled for each evaluation
        outdir=name)                    # Save everything to the `name` directory

    return agent
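For completeness, a hedged invocation sketch (the argument values are illustrative; create_ddqn_agent is the module's own default factory):

agent = train_agent(rounds=20000, use_score=True, name='dqn_score_results')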

