train_and_summarize.py source code

Project: miccai-2016-surgical-activity-rec · Author: rdipietro
import os

import tensorflow as tf

# Project-local modules (data loading, model definitions, and the optimizer
# wrapper) from the same repository.
import data
import models
import optimizers

# Helpers used below (define_and_process_args, get_log_dir, train) are
# defined elsewhere in this file and omitted from this excerpt.


def main():
    """ Run training and export summaries to data_dir/logs for a single test
    setup and a single set of parameters. Summaries include a) TensorBoard
    summaries, b) the latest train/test accuracies and raw edit distances
    (status.txt), c) the latest test predictions along with test ground-truth
    labels (test_label_seqs.pkl, test_prediction_seqs.pkl), d) visualizations
    as training progresses (test_visualizations_######.png)."""

    args = define_and_process_args()
    print('\n', 'ARGUMENTS', '\n\n', args, '\n')

    log_dir = get_log_dir(args)
    print('\n', 'LOG DIRECTORY', '\n\n', log_dir, '\n')

    standardized_data_path = os.path.join(args.data_dir, args.data_filename)
    if not os.path.exists(standardized_data_path):
        message = '%s does not exist.' % standardized_data_path
        raise ValueError(message)

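    # Load the standardized data file and split its sequences into train and
    # test sets according to the held-out test users.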
    dataset = data.Dataset(standardized_data_path)
    train_raw_seqs, test_raw_seqs = dataset.get_splits(args.test_users)
    train_triplets = [data.prepare_raw_seq(seq) for seq in train_raw_seqs]
    test_triplets = [data.prepare_raw_seq(seq) for seq in test_raw_seqs]

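    # Each triplet is (input_seq, reset_seq, label_seq); zip(*...) transposes
    # the list of triplets into parallel tuples of sequences.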
    train_input_seqs, train_reset_seqs, train_label_seqs = zip(*train_triplets)
    test_input_seqs, test_reset_seqs, test_label_seqs = zip(*test_triplets)

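    # Look up the model class named '<model_type>Model' in the models module.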
    Model = eval('models.' + args.model_type + 'Model')
    input_size = dataset.input_size
    target_size = dataset.num_classes

    # This is just to satisfy a low-CPU requirement on our cluster
    # when using GPUs.
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        config = tf.ConfigProto(intra_op_parallelism_threads=2,
                                inter_op_parallelism_threads=2)
    else:
        config = None

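    # Build the model and optimizer, then run the full training loop within a
    # single TensorFlow session.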
    with tf.Session(config=config) as sess:
        model = Model(input_size, target_size, args.num_layers,
                      args.hidden_layer_size, args.init_scale,
                      args.dropout_keep_prob)
        optimizer = optimizers.Optimizer(
            model.loss, args.num_train_sweeps, args.initial_learning_rate,
            args.num_initial_sweeps, args.num_sweeps_per_decay,
            args.decay_factor, args.max_global_grad_norm)
        train(sess, model, optimizer, log_dir, args.batch_size,
              args.num_sweeps_per_summary, args.num_sweeps_per_save,
              train_input_seqs, train_reset_seqs, train_label_seqs,
              test_input_seqs, test_reset_seqs, test_label_seqs)
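
The triplet handling above assumes data.prepare_raw_seq() returns an (input_seq, reset_seq, label_seq) triplet for each raw sequence; zip(*triplets) then transposes the list of triplets into three parallel tuples. Below is a minimal, self-contained sketch of that transposition, using dummy placeholder strings rather than the project's data module; the names are illustrative only.

# Dummy stand-ins for the (input_seq, reset_seq, label_seq) triplets that
# data.prepare_raw_seq() is assumed to produce.
triplets = [
    ('inputs_0', 'resets_0', 'labels_0'),
    ('inputs_1', 'resets_1', 'labels_1'),
]

# zip(*triplets) groups the first, second, and third elements of every
# triplet together, yielding parallel tuples of inputs, resets, and labels.
input_seqs, reset_seqs, label_seqs = zip(*triplets)

print(input_seqs)  # ('inputs_0', 'inputs_1')
print(reset_seqs)  # ('resets_0', 'resets_1')
print(label_seqs)  # ('labels_0', 'labels_1')

The same pattern produces train_input_seqs, train_reset_seqs, and train_label_seqs (and their test counterparts) in main() above.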