def train(config=None, reporter=None):
    """Set up module-level training state and hand control to ``tf.app.run``.

    Args:
        config: Mapping that must contain an ``'activation'`` key whose value
            names an attribute of ``tf.nn`` (for example ``'relu'``). When
            omitted, ``{'activation': 'relu'}`` is used.
        reporter: Stored as-is in the module-level ``status_reporter``; this
            function does not call it. Presumably a ray.tune status
            reporter -- confirm against the caller.

    Raises:
        KeyError: If ``config`` has no ``'activation'`` entry.
        AttributeError: If ``tf.nn`` has no attribute with that name.

    Side effects:
        Rebinds the module globals ``FLAGS``, ``status_reporter`` and
        ``activation_fn``, and parses the process command line.
    """
    global FLAGS, status_reporter, activation_fn

    # Build the default per call instead of sharing one dict object across
    # every invocation (the mutable-default-argument pitfall).
    if config is None:
        config = {'activation': 'relu'}

    status_reporter = reporter
    activation_fn = getattr(tf.nn, config['activation'])

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
        help='Directory for storing input data')

    # parse_known_args keeps unrecognised flags in `unparsed` so they can be
    # forwarded to tf.app.run instead of causing an argparse error.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
# !!! Example of using the ray.tune Python API !!!
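# Comment list (page-navigation text left over from the source web page)
# Article table of contents (page-navigation text left over from the source web page)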