def CommandLine(args=None):
    '''
    Command-line entry point: build, train and evaluate a TFLearnWideAndDeep model.

    The hyper-parameters come from TensorFlow command-line flags. For unit
    testing, a caller may pass ``args`` instead of relying on argv.

    Args:
        args: optional mapping of flag name -> value (presumably; it is merged
            straight into FLAGS.__dict__, so any dict-like works). When truthy,
            the flag object is re-initialised first and these values are set
            as attributes on it.

    Returns:
        The trained and evaluated TFLearnWideAndDeep instance.
    '''
    tf_flags = tf.app.flags
    flag_values = tf_flags.FLAGS

    if args:
        # Reset flag state, then plant the caller's values directly as
        # attributes so later FLAGS.<name> lookups see them.
        flag_values.__init__()
        flag_values.__dict__.update(args)

    try:
        # (definer, name, default, help) — defined in this exact order; the
        # first failing definition aborts the rest, as in a straight-line list.
        flag_specs = (
            (tf_flags.DEFINE_string, "model_type", "wide+deep", "Valid model types: {'wide', 'deep', 'wide+deep'}."),
            (tf_flags.DEFINE_string, "run_name", None, "name for this run (defaults to model type)"),
            (tf_flags.DEFINE_string, "load_weights", None, "filename with initial weights to load"),
            (tf_flags.DEFINE_string, "checkpoints_dir", None, "name of directory where checkpoints should be saved"),
            (tf_flags.DEFINE_integer, "n_epoch", 200, "Number of training epoch steps"),
            (tf_flags.DEFINE_integer, "snapshot_step", 100, "Step number when snapshot (and validation testing) is done"),
            (tf_flags.DEFINE_float, "wide_learning_rate", 0.001, "learning rate for the wide part of the model"),
            (tf_flags.DEFINE_float, "deep_learning_rate", 0.001, "learning rate for the deep part of the model"),
            (tf_flags.DEFINE_boolean, "verbose", False, "Verbose output"),
        )
        for define_flag, flag_name, flag_default, flag_help in flag_specs:
            define_flag(flag_name, flag_default, flag_help)
    except argparse.ArgumentError:
        # Flags already exist from an earlier call; tolerated so that
        # CommandLine can be run more than once, for testing.
        pass

    twad = TFLearnWideAndDeep(
        model_type=flag_values.model_type,
        verbose=flag_values.verbose,
        name=flag_values.run_name,
        wide_learning_rate=flag_values.wide_learning_rate,
        deep_learning_rate=flag_values.deep_learning_rate,
        checkpoints_dir=flag_values.checkpoints_dir,
    )
    twad.load_data()

    weights_path = flag_values.load_weights
    if weights_path:
        print("Loading initial weights from %s" % weights_path)
        twad.model.load(weights_path)

    twad.train(n_epoch=flag_values.n_epoch, snapshot_step=flag_values.snapshot_step)
    twad.evaluate()
    return twad
#-----------------------------------------------------------------------------
# unit tests
评论列表
文章目录