train.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:agent 作者: sintefneodroid 项目源码 文件源码
def main():
  """Set up the environment and models, run training, and save the result.

  Steps, all driven by the configuration object ``C``:

  1. Optionally create a Visdom visualiser (``C.USE_VISDOM``).
  2. Create the environment with ``neo.make`` and seed it with
     ``C.RANDOM_SEED``.
  3. Replace string placeholders in ``C.ARCH_PARAMS['input_size']`` and
     ``C.ARCH_PARAMS['output_size']`` with the environment's observation
     space shape and number of actions respectively.
  4. Build the model and a target model with identical weights, optionally
     restoring previously saved weights first.
  5. Run ``training_loop`` and persist the returned model with
     ``save_model``.

  The environment is closed even when setup or training raises, so a failed
  run does not leave the environment connection open.

  :return: None
  """
  # Imported locally so this function does not depend on the module's import
  # block; the models are already torch modules (state_dict / cuda).
  import torch

  _visualiser = None
  if C.USE_VISDOM:
    _visualiser = Visdom(C.VISDOM_SERVER)

  _environment = neo.make(C.ENVIRONMENT,
                          connect_to_running=C.CONNECT_TO_RUNNING,
                          logging_directory=C.LOGGING_DIRECTORY,
                          debug_logging=C.USE_LOGGING)
  try:
    _environment.seed(C.RANDOM_SEED)

    # A string value acts as a placeholder meaning "infer from environment".
    if isinstance(C.ARCH_PARAMS['input_size'], str):
      C.ARCH_PARAMS['input_size'] = _environment.observation_space.shape
    print('observation dimensions: ', C.ARCH_PARAMS['input_size'])

    if isinstance(C.ARCH_PARAMS['output_size'], str):
      C.ARCH_PARAMS['output_size'] = _environment.action_space.n
    print('action dimensions: ', C.ARCH_PARAMS['output_size'])

    _model = C.ARCH(C.ARCH_PARAMS)
    if C.LOAD_PREVIOUS_MODEL_IF_AVAILABLE:
      _model.load_state_dict(load_model(C))

    # The target model starts as an exact copy of the (possibly restored)
    # model weights.
    _target_model = C.ARCH(C.ARCH_PARAMS)
    _target_model.load_state_dict(_model.state_dict())

    # Honour the "if available" part of the flag: calling .cuda() on a
    # machine without CUDA raises instead of falling back to the CPU.
    if C.USE_CUDA_IF_AVAILABLE and torch.cuda.is_available():
      _model = _model.cuda()
      _target_model.cuda()

    _trained_model = training_loop(_model,
                                   _target_model,
                                   _environment,
                                   _visualiser)
  finally:
    # Always release the environment, including on failure or interrupt.
    _environment.close()

  save_model(_trained_model, C)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号