def build_training_graph(self, dataset):
  """Builds the graph to use for training a model.

  This operates on the current default graph.

  Args:
    dataset: The dataset to use during training.
  Returns:
    The set of tensors and ops references required for training.
  """
  with tf.name_scope('input'):
    # For training, ensure the data is shuffled; the batch size and number of epochs
    # come from the training arguments. The datasource to use is the one named
    # 'train' within the dataset.
    inputs = self.build_input(dataset, 'train',
                              batch=self.args.batch_size,
                              epochs=self.args.epochs,
                              shuffle=True)
  with tf.name_scope('inference'):
    inferences = self.build_inference(inputs, training=True)
  with tf.name_scope('train'):
    # The global steps variable is explicitly marked as trainable so that the saver
    # below (which only collects trainable variables) saves it into checkpoints,
    # allowing training to be resumed from the correct step.
    global_steps = tf.Variable(0, name='global_steps', dtype=tf.int64, trainable=True,
                               collections=[tf.GraphKeys.GLOBAL_VARIABLES,
                                            tf.GraphKeys.GLOBAL_STEP,
                                            tf.GraphKeys.TRAINABLE_VARIABLES])

    loss, train_op = self.build_training(global_steps, inputs, inferences)
  with tf.name_scope('initialization'):
    # Create the saver that will be used to save trained variables, and to restore
    # them when training is resumed.
    saver = tf.train.Saver(tf.trainable_variables(), sharded=True)

    init_op, local_init_op = self.build_init()
    ready_op = tf.report_uninitialized_variables(tf.trainable_variables())

  # Create the summary op that merges all summaries across the sub-graphs.
  summary_op = tf.summary.merge_all()
  scaffold = tf.train.Scaffold(init_op=init_op,
                               local_init_op=local_init_op,
                               ready_op=ready_op,
                               ready_for_local_init_op=ready_op,
                               summary_op=summary_op,
                               saver=saver)
  scaffold.finalize()
  return {
    'global_steps': global_steps,
    'loss': loss,
    'init_op': init_op,
    'local_init_op': local_init_op,
    'ready_op': ready_op,
    'train_op': train_op,
    'summary_op': summary_op,
    'saver': saver,
    'scaffold': scaffold
  }
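
# A minimal usage sketch (not part of the method above) showing how the returned
# references and scaffold might be consumed with a MonitoredSession. The names
# `model`, `dataset`, and `output_dir` are hypothetical placeholders assumed only
# for this illustration.
training = model.build_training_graph(dataset)

session_creator = tf.train.ChiefSessionCreator(scaffold=training['scaffold'],
                                               checkpoint_dir=output_dir)
with tf.train.MonitoredSession(session_creator=session_creator) as session:
  # Run until the input queues raise OutOfRangeError (i.e. the configured number
  # of epochs has been consumed), at which point should_stop() becomes True.
  while not session.should_stop():
    _, loss_value, step = session.run(
        [training['train_op'], training['loss'], training['global_steps']])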