def __setup_training(self, images, labels):
    """Build the training graph for one input batch and run the slim train loop.

    Runs ``self.network_fn`` on ``images`` and builds the training loss. The
    loss is a softmax cross-entropy on the main logits, plus a 0.4-weighted
    auxiliary cross-entropy when the network exposes ``'AuxLogits'``, plus any
    regularization losses registered in the graph. It then configures the
    learning rate, optimizer and train op, adds summaries, and hands control
    to ``slim.learning.train`` with the checkpoint/summary settings stored on
    ``self``.

    Args:
        images: input batch passed straight to ``self.network_fn``.
        labels: one-hot encoded targets (they are fed as ``onehot_labels``);
            presumably shaped like the network's logits -- confirm with caller.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    logits, end_points = self.network_fn(images)

    #############################
    # Specify the loss function #
    #############################
    losses = []
    if 'AuxLogits' in end_points:
        # Auxiliary classifier head, when the network provides one; weighted
        # at 0.4 so it only assists the main objective.
        losses.append(tf.losses.softmax_cross_entropy(
            logits=end_points['AuxLogits'], onehot_labels=labels,
            label_smoothing=self.label_smoothing, weights=0.4,
            scope='aux_loss'))
    losses.append(tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=labels,
        label_smoothing=self.label_smoothing, weights=1.0))
    # Include regularization terms (e.g. weight decay -- presumably added by
    # network_fn's regularizers; confirm). Without them those penalties would
    # exist in the graph but never be optimized.
    losses.extend(tf.losses.get_regularization_losses())
    total_loss = tf.add_n(losses, name='total_loss')

    # get_or_create avoids a ValueError if a global step already exists in the
    # default graph (e.g. this method is invoked a second time).
    global_step = tf.train.get_or_create_global_step()

    # Variables to train.
    variables_to_train = self.__get_variables_to_train()
    # The learning-rate schedule is derived from the dataset size and global
    # step (details live in __configure_learning_rate).
    learning_rate = self.__configure_learning_rate(
        self.dataset.num_samples, global_step)
    optimizer = self.__configure_optimizer(learning_rate)
    train_op = slim.learning.create_train_op(
        total_loss, optimizer, variables_to_train=variables_to_train)
    self.__add_summaries(end_points, learning_rate, total_loss)

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_op,
        logdir=self.train_dir,
        init_fn=self.__get_init_fn(),
        number_of_steps=self.max_number_of_steps,
        log_every_n_steps=self.log_every_n_steps,
        save_summaries_secs=self.save_summaries_secs,
        save_interval_secs=self.save_interval_secs)
评论列表
文章目录