_model.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tensorfx 作者: TensorLab 项目源码 文件源码
def build_training_graph(self, dataset):
    """Constructs the training graph for this model.

    All ops are added to whatever graph is currently the default.

    Args:
      dataset: The dataset to train on; its 'train' datasource is the one that is read.
    Returns:
      A dictionary of the tensors and ops needed to drive training, keyed by 'global_steps',
      'loss', 'init_op', 'local_init_op', 'ready_op', 'train_op', 'summary_op', 'saver' and
      'scaffold'.
    """
    with tf.name_scope('input'):
        # Training reads the 'train' datasource with shuffling enabled. Batch size and the
        # epochs setting are both forwarded from the model arguments.
        training_inputs = self.build_input(dataset, 'train',
                                           batch=self.args.batch_size,
                                           epochs=self.args.epochs,
                                           shuffle=True)

    with tf.name_scope('inference'):
        training_inferences = self.build_inference(training_inputs, training=True)

    with tf.name_scope('train'):
        # The step counter is deliberately registered as a trainable variable: the saver created
        # below only covers trainable variables, so this is what gets the counter written into
        # checkpoints and makes resumed training possible.
        step_collections = [tf.GraphKeys.GLOBAL_VARIABLES,
                            tf.GraphKeys.GLOBAL_STEP,
                            tf.GraphKeys.TRAINABLE_VARIABLES]
        step_counter = tf.Variable(0, name='global_steps', dtype=tf.int64, trainable=True,
                                   collections=step_collections)
        loss, train_op = self.build_training(step_counter, training_inputs, training_inferences)

    with tf.name_scope('initialization'):
        # Sharded saver over the trainable variables; used both to write checkpoints and to
        # restore them when training is resumed.
        checkpoint_saver = tf.train.Saver(tf.trainable_variables(), sharded=True)

        init_op, local_init_op = self.build_init()

        # Readiness is defined as "no trainable variable is left uninitialized".
        ready_op = tf.report_uninitialized_variables(tf.trainable_variables())

    # A single op that merges every summary registered across all of the sub-graphs above.
    merged_summaries = tf.summary.merge_all()

    # The same readiness check gates both overall readiness and the local-init step.
    scaffold_args = dict(init_op=init_op,
                         local_init_op=local_init_op,
                         ready_op=ready_op,
                         ready_for_local_init_op=ready_op,
                         summary_op=merged_summaries,
                         saver=checkpoint_saver)
    training_scaffold = tf.train.Scaffold(**scaffold_args)
    training_scaffold.finalize()

    return dict(global_steps=step_counter,
                loss=loss,
                init_op=init_op,
                local_init_op=local_init_op,
                ready_op=ready_op,
                train_op=train_op,
                summary_op=merged_summaries,
                saver=checkpoint_saver,
                scaffold=training_scaffold)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号