train_estimator.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Classification_Nets 作者: BobLiu20 项目源码 文件源码
def main(_):
    # set up TF environment
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpus
    gpus_list = FLAGS.gpus.split(',')
    # save prefix
    prefix = '%s/%s/%s/%d' % (FLAGS.working_root, FLAGS.dataset_name,
                              FLAGS.model_name, FLAGS.try_num)
    if not os.path.exists(prefix):
        os.makedirs(prefix)
    # start
    model_params = {"num_classes": 10, "gpus_list": gpus_list}
    run_config = tf.estimator.RunConfig()
    run_config = run_config.replace(
        model_dir=prefix,
        log_step_count_steps=100,
        save_checkpoints_secs=600,
        session_config=tf.ConfigProto(allow_soft_placement=True,
                                      gpu_options=tf.GPUOptions(allow_growth=True)))
    nn = tf.estimator.Estimator(
        model_fn=model_fn, params=model_params, config=run_config)
    nn.train(input_fn=lambda: input_fn(
        len(gpus_list)), steps=None, max_steps=None)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号