task.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:ISLES2017 作者: MiguelMonteiro 项目源码 文件源码
def run(target, is_chief, train_steps, job_dir, train_files, eval_files, num_epochs, learning_rate,
        num_channels=6):
    """Build the training graph and run training inside a MonitoredTrainingSession.

    On the chief worker an extra, separate evaluation graph is built and attached
    through ``EvaluationRunHook``, and a ``CheckpointExporterHook`` is registered.

    Args:
        target: Master address passed to ``MonitoredTrainingSession(master=...)``.
        is_chief: Whether this worker is the chief. Only the chief builds the
            evaluation graph and the checkpoint-exporter hook.
        train_steps: Step budget passed to ``tf.train.StopAtStepHook`` as ``num_steps``.
        job_dir: Checkpoint directory for the session; also handed to the evaluation
            and exporter hooks and to ``get_summary_dir`` for training summaries.
        train_files: Files read by the (shuffled) training input pipeline.
        eval_files: Files read by the evaluation input pipeline (1 epoch, unshuffled).
        num_epochs: Number of epochs passed to the training ``model.input_fn``.
        learning_rate: Forwarded to ``model.model_fn`` for both the EVAL and TRAIN graphs.
        num_channels: Number of input channels forwarded to ``model.model_fn``.
            Defaults to 6, the value this function previously hard-coded.
    """
    hooks = []
    # With num_steps, StopAtStepHook stops `train_steps` steps after the *global* step
    # it observes when this worker's session is created (it does not count local
    # steps). Workers that start at different times, or a run resumed from a
    # checkpoint, therefore stop at different global steps.
    # NOTE(review): if an absolute limit is intended, `last_step=train_steps` may be
    # what is wanted — confirm before changing.
    hooks.append(tf.train.StopAtStepHook(num_steps=train_steps))

    if is_chief:
        evaluation_graph = tf.Graph()
        with evaluation_graph.as_default():
            # Features, label and name tensors for evaluation: one pass, no shuffling,
            # and no shared queue name.
            image, ground_truth, name = model.input_fn(eval_files, 1, shuffle=False, shared_name=None)
            # In EVAL mode model_fn returns a dictionary of tensors to be evaluated.
            metric_dict = model.model_fn(model.EVAL, name, image, ground_truth, num_channels, learning_rate)
            # Hook that runs evaluation in its own graph, separate from training.
            hooks.append(EvaluationRunHook(job_dir, metric_dict, evaluation_graph))
        hooks.append(CheckpointExporterHook(job_dir))

    # Create a new graph for training and make it the default.
    with tf.Graph().as_default():
        # replica_device_setter places variables on parameter servers in distributed mode.
        with tf.device(tf.train.replica_device_setter()):

            # Features, label and name tensors read through a filename queue that is
            # shared across workers under the name 'train_queue'.
            image, ground_truth, name = model.input_fn(train_files, num_epochs, shuffle=True, shared_name='train_queue')

            # In TRAIN mode model_fn returns the training op, a logging hook and the
            # training summary op.
            train_op, log_hook, train_summaries = model.model_fn(model.TRAIN, name, image, ground_truth,
                                                                 num_channels, learning_rate)
            # Hook that logs training progress to the console.
            hooks.append(log_hook)

            # Writes `train_summaries` every step to the training summary directory.
            # NOTE(review): this hook is added on every worker (not only the chief), and
            # MonitoredTrainingSession below also saves summaries every step on the chief
            # (save_summaries_steps=1) — confirm the duplicated summary output is intended.
            train_summary_hook = tf.train.SummarySaverHook(save_steps=1, output_dir=get_summary_dir(job_dir),
                                                           summary_op=train_summaries)
            hooks.append(train_summary_hook)

        # MonitoredTrainingSession is a Session-like object that handles
        # initialization, recovery and hooks; checkpoints are saved every 3 minutes.
        # https://www.tensorflow.org/api_docs/python/tf/train/MonitoredTrainingSession
        with tf.train.MonitoredTrainingSession(master=target,
                                               is_chief=is_chief,
                                               checkpoint_dir=job_dir,
                                               hooks=hooks,
                                               save_checkpoint_secs=60*3,
                                               save_summaries_steps=1,
                                               log_step_count_steps=5) as session:
            # Run the training op until the session reports it should stop: e.g. when
            # StopAtStepHook requests a stop, or (presumably) when the input queue is
            # exhausted after `num_epochs` — the result of session.run is not used.
            while not session.should_stop():
                session.run(train_op)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号