def parse_args(argv=None):
    """Parse input arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        A ``(args, parser)`` tuple: the parsed ``argparse.Namespace`` and the
        ``ArgumentParser`` that produced it.
    """
    def _str_to_bool(value):
        # argparse passes option values as strings, so with a bare
        # ``default=False`` the command line ``--save False`` yielded the
        # truthy string 'False'. Map the usual "off" spellings (and the empty
        # string, which was already falsy) to False; every other value stays
        # truthy exactly as it was before.
        return value.lower() not in ('', 'false', 'f', '0', 'no', 'n', 'off')

    parser = argparse.ArgumentParser(description='Deep360Pilot')
    parser.add_argument('--opt', dest='opt_method', help='[Adam, Adadelta, RMSProp]', default='Adam')
    parser.add_argument('--root', dest='root_path', help='root path of data', default='./')
    parser.add_argument('--data', dest='data_path', help='data path of data', default='./data/')
    parser.add_argument('--mode', dest='mode', help='[train, test, vid, pred]', required=True)
    parser.add_argument('--model', dest='model_path', help='model path to load')
    parser.add_argument('--gpu', dest='gpu', help='Choose which gpu to use', default='0')
    parser.add_argument('-n', '--name', dest='video_name', help='youtube_id + _ + part')
    parser.add_argument('-d', '--domain', dest='domain', help='skate, skiing, ...', required=True)
    parser.add_argument('-l', '--lambda', dest='lam', help='movement tradeoff lambda, the higher the smoother.', type=float, required=True)
    parser.add_argument('-b', '--boxnum', dest='boxnum', help='boxes number, Use integer, [8, 16, 32]', type=int, required=True)
    parser.add_argument('-p', '--phase', dest='phase', help='phase [classify, regress]', required=True)
    # ``--save`` alone now means True; ``--save <value>`` is still accepted.
    parser.add_argument('-s', '--save', dest='save', help='save images for debug',
                        type=_str_to_bool, nargs='?', const=True, default=False)
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--debug', dest='debug', help='Start debug mode or not', action='store_true')
    args = parser.parse_args(argv)
    return args, parser
# Example source code for the Python train() function
def main():
    """Parse the YOLO command line, seed PyTorch and call ``train.train(args)``.

    Options: ``--use_cuda`` (bool), ``--epochs``, ``--batch_size``, ``--lr``
    and ``--seed``. ``--use_cuda`` also toggles
    ``torch.backends.cudnn.benchmark``.
    """
    def _str_to_bool(value):
        # ``type=bool`` is an argparse pitfall: bool('False') is True, so
        # ``--use_cuda False`` used to enable CUDA. Map the usual "off"
        # spellings (and '', already falsy) to False; anything else stays True.
        return value.lower() not in ('', 'false', 'f', '0', 'no', 'n', 'off')

    parser = argparse.ArgumentParser(description='PyTorch YOLO')
    parser.add_argument('--use_cuda', type=_str_to_bool, default=False,
                        help='use cuda or not')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Epochs')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate')
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed')
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = args.use_cuda
    train.train(args)
def main(_):
    """Prepare output directories, restore or initialize a DCGAN, then
    optionally train it and run the sampling/visualization routines.

    The single positional argument is ignored (presumably the argv list
    handed over by a ``tf.app.run``-style launcher — TODO confirm).

    Raises:
        IOError: if a checkpoint exists but cannot be loaded, or if no
            checkpoint exists while ``FLAGS.train`` is false.
    """
    # ``print(...)`` with a single argument behaves the same on Python 2 and
    # 3; the old ``print "..."`` statements were a SyntaxError on Python 3.
    pp.pprint(FLAGS.__flags)
    # training/inference
    with tf.Session() as sess:
        dcgan = DCGAN(sess, FLAGS)
        # path checks: compute the model sub-directory once and create any
        # output directory that is missing.
        model_dir = dcgan.get_model_dir()
        for path in (FLAGS.checkpoint_dir,
                     os.path.join(FLAGS.log_dir, model_dir),
                     os.path.join(FLAGS.sample_dir, model_dir)):
            if not os.path.exists(path):
                os.makedirs(path)
        # load checkpoint if found
        if dcgan.checkpoint_exists():
            print("Loading checkpoints...")
            if dcgan.load():
                print("success!")
            else:
                raise IOError("Could not read checkpoints from {0}!".format(
                    FLAGS.checkpoint_dir))
        else:
            if not FLAGS.train:
                raise IOError("No checkpoints found but need for sampling!")
            print("No checkpoints found. Training from scratch.")
            # NOTE(review): load() is called even though no checkpoint
            # exists; presumably it initializes variables in that case —
            # confirm against DCGAN.load.
            dcgan.load()
        # train DCGAN
        if FLAGS.train:
            train(dcgan)
        # inference/visualization code goes here
        print("Generating samples...")
        inference.sample_images(dcgan)
        print("Generating visualizations of z...")
        inference.visualize_z(dcgan)
def main(_):
    """Build the multi-GPU DCGAN described by FLAGS, then either train it
    (``FLAGS.is_train``) or load a checkpoint and run ``visualize``.

    The single positional argument is ignored (presumably the argv list
    handed over by a ``tf.app.run``-style launcher — TODO confirm).

    Raises:
        NotImplementedError: if ``FLAGS.dataset`` is ``'mnist'``.
    """
    pp.pprint(flags.FLAGS.__flags)
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)) as sess:
        if FLAGS.dataset == 'mnist':
            # Previously ``assert False``: assertions are stripped under
            # ``python -O``, which would silently continue with mnist.
            raise NotImplementedError(
                "The 'mnist' dataset is not supported by this model.")
        dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                      sample_size=64,
                      z_dim=8192,
                      d_label_smooth=.25,
                      generator_target_prob=.75 / 2.,
                      out_stddev=.075,
                      out_init_b=- .45,
                      image_shape=[FLAGS.image_width, FLAGS.image_width, 3],
                      dataset_name=FLAGS.dataset, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir,
                      generator=Generator(),
                      train_func=train, discriminator_func=discriminator,
                      build_model_func=build_model, config=FLAGS,
                      devices=["gpu:0", "gpu:1", "gpu:2", "gpu:3"]  # , "gpu:4"]
                      )
        # Single-argument print() behaves identically on Python 2 and 3.
        if FLAGS.is_train:
            print("TRAINING")
            dcgan.train(FLAGS)
            print("DONE TRAINING")
        else:
            dcgan.load(FLAGS.checkpoint_dir)
            # Mode selector passed to visualize(); its meaning is defined
            # there — TODO document.
            OPTION = 2
            visualize(sess, dcgan, FLAGS, OPTION)
def main(job_id, params):
    """Log the given hyper-parameters, then train with them.

    Args:
        job_id: accepted for the caller's calling convention; not used here.
        params: mapping of keyword arguments forwarded to ``train``.

    Returns:
        Whatever ``train(**params)`` returns (named the validation error).
    """
    rendered_options = pprint.pformat(params)
    logger.info("Model options:\n{}".format(rendered_options))
    return train(**params)
def train_word2vec(corpus_file, out_file, **kwargs):
    """Delegate to ``word2vec.train`` with the corpus and output paths.

    Extra keyword arguments are forwarded unchanged. Returns ``None``.
    (Presumably trains embeddings on ``corpus_file`` and writes them to
    ``out_file`` — TODO confirm against the word2vec module.)
    """
    positional = (corpus_file, out_file)
    word2vec.train(*positional, **kwargs)
def train_model(db_file, entity_db_file, vocab_file, word2vec, **kwargs):
    """Load the training resources and hand them to ``train.train``.

    Args:
        db_file: path opened read-only as an ``AbstractDB``.
        entity_db_file: path loaded with ``EntityDB.load``.
        vocab_file: path loaded with ``Vocab.load``.
        word2vec: optional model path; when falsy, no ``ModelReader`` is
            created and ``None`` is passed instead.
        **kwargs: forwarded unchanged to ``train.train``.
    """
    abstract_db = AbstractDB(db_file, 'r')
    entity_store = EntityDB.load(entity_db_file)
    vocab_table = Vocab.load(vocab_file)
    embedding_reader = ModelReader(word2vec) if word2vec else None
    train.train(abstract_db, entity_store, vocab_table, embedding_reader, **kwargs)
def main(outputName):
    """Load the tree datasets and grid-search the RNTN hyper-parameters.

    Every combination of ``miniBatchSize`` x ``adagradResetNbIter`` x
    ``learningRate`` x ``regularisationTerm`` is trained once via
    ``train.train(outputName, datasets, params)``, with ``nbEpoch`` epochs.

    Args:
        outputName: recording target passed to every ``train.train`` call
            (presumably an output file or prefix — TODO confirm).
    """
    import itertools

    print("Welcome into RNTN implementation 0.6 (recording will be on ", outputName, ")")
    random.seed("MetaMind")  # Lucky seed ? Fixed seed for replication
    np.random.seed(7)

    print("Parsing dataset, creating dictionary...")
    # Dictionary initialisation
    vocabulary.initVocab()

    # Loading dataset
    datasets = {}
    datasets['training'] = utils.loadDataset("trees/train.txt")
    print("Training loaded !")
    datasets['testing'] = utils.loadDataset("trees/test.txt")
    print("Testing loaded !")
    datasets['validating'] = utils.loadDataset("trees/dev.txt")
    print("Validation loaded !")
    print("Datasets loaded !")
    print("Nb of words", vocabulary.vocab.length())

    # Datatransform (normalisation, remove outliers,...) ?? > Not here
    # Grid search on our hyperparameters (too long for complete k-fold cross
    # validation so just train/test). itertools.product varies the last
    # iterable fastest, reproducing the nested-loop order mBS > aRNI > lR > rT.
    grid = itertools.product(miniBatchSize, adagradResetNbIter,
                             learningRate, regularisationTerm)
    for mBS, aRNI, lR, rT in grid:
        params = {
            "nbEpoch": nbEpoch,
            "learningRate": lR,
            "regularisationTerm": rT,
            "adagradResetNbIter": aRNI,
            "miniBatchSize": mBS,
        }
        # No need to reset the vocabulary values (contained in model.L so
        # automatically reset). Same for the training and testing set
        # (output values recomputed at each iteration).
        _model, _errors = train.train(outputName, datasets, params)
        # TODO: Plot the cross-validation curve
        # TODO: Plot a heat map of the hyperparameters cost to help tuning them?

    # TODO(final training): validate the last computed model with
    # model.computeError(datasets['validating'], True) and print the error.
    print("The End. Thank you for using this program!")