def main():
    """Run training and export summaries for one test setup and parameter set.

    Summaries are exported to data_dir/logs and include a) TensorBoard
    summaries, b) the latest train/test accuracies and raw edit distances
    (status.txt), c) the latest test predictions along with test ground-truth
    labels (test_label_seqs.pkl, test_prediction_seqs.pkl), d) visualizations
    as training progresses (test_visualizations_######.png).

    Raises:
        ValueError: If the file at data_dir/data_filename does not exist.
        AttributeError: If `models` has no attribute named
            `<args.model_type>Model`.
    """
    args = define_and_process_args()
    print('\n', 'ARGUMENTS', '\n\n', args, '\n')

    log_dir = get_log_dir(args)
    print('\n', 'LOG DIRECTORY', '\n\n', log_dir, '\n')

    # Fail early with a clear message instead of letting the dataset loader
    # hit a missing file.
    standardized_data_path = os.path.join(args.data_dir, args.data_filename)
    if not os.path.exists(standardized_data_path):
        message = '%s does not exist.' % standardized_data_path
        raise ValueError(message)

    # Split the raw sequences by test user, then turn each raw sequence into
    # an (input, reset, label) triplet and regroup the triplets per field.
    dataset = data.Dataset(standardized_data_path)
    train_raw_seqs, test_raw_seqs = dataset.get_splits(args.test_users)
    train_triplets = [data.prepare_raw_seq(seq) for seq in train_raw_seqs]
    test_triplets = [data.prepare_raw_seq(seq) for seq in test_raw_seqs]
    train_input_seqs, train_reset_seqs, train_label_seqs = zip(*train_triplets)
    test_input_seqs, test_reset_seqs, test_label_seqs = zip(*test_triplets)

    # Resolve the model class by attribute lookup. This replaces an eval() of
    # a string built from args.model_type, which would execute any expression
    # smuggled into that value.
    Model = getattr(models, args.model_type + 'Model')
    input_size = dataset.input_size
    target_size = dataset.num_classes

    # This is just to satisfy a low-CPU requirement on our cluster
    # when using GPUs.
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        config = tf.ConfigProto(intra_op_parallelism_threads=2,
                                inter_op_parallelism_threads=2)
    else:
        config = None

    with tf.Session(config=config) as sess:
        model = Model(input_size, target_size, args.num_layers,
                      args.hidden_layer_size, args.init_scale,
                      args.dropout_keep_prob)
        optimizer = optimizers.Optimizer(
            model.loss, args.num_train_sweeps, args.initial_learning_rate,
            args.num_initial_sweeps, args.num_sweeps_per_decay,
            args.decay_factor, args.max_global_grad_norm)
        train(sess, model, optimizer, log_dir, args.batch_size,
              args.num_sweeps_per_summary, args.num_sweeps_per_save,
              train_input_seqs, train_reset_seqs, train_label_seqs,
              test_input_seqs, test_reset_seqs, test_label_seqs)
train_and_summarize.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录