def main(_):
    """Configure the visible GPUs, prepare the model directory and train.

    The single positional argument is ignored. All settings are read from
    the module-level ``FLAGS`` object (``gpus``, ``working_root``,
    ``dataset_name``, ``model_name``, ``try_num``).
    """
    # Restrict TensorFlow to the requested devices; the same comma-separated
    # string is split into the per-device list handed to the model.
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpus
    device_ids = FLAGS.gpus.split(',')

    # Checkpoints and summaries go to <root>/<dataset>/<model>/<try>.
    model_dir = '%s/%s/%s/%d' % (FLAGS.working_root, FLAGS.dataset_name,
                                 FLAGS.model_name, FLAGS.try_num)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    hyper_params = {"num_classes": 10, "gpus_list": device_ids}

    estimator_config = tf.estimator.RunConfig().replace(
        model_dir=model_dir,
        log_step_count_steps=100,
        save_checkpoints_secs=600,
        session_config=tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True)))

    estimator = tf.estimator.Estimator(
        model_fn=model_fn, params=hyper_params, config=estimator_config)

    def _train_input_fn():
        # The input pipeline is given the number of devices in use.
        return input_fn(len(device_ids))

    # No step limit is passed: both steps and max_steps are None.
    estimator.train(input_fn=_train_input_fn, steps=None, max_steps=None)
评论列表
文章目录