import tensorflow as tf


def training_step(loss, optimizer_handle, learning_rate, **kwargs):
    '''
    Creates the optimisation operation which is executed in each training iteration of the network
    :param loss: The loss to be minimised
    :param optimizer_handle: A handle to one of the tf optimisers
    :param learning_rate: Learning rate
    :param momentum: Optionally, a momentum term can be passed to the optimiser via kwargs
    :return: The training operation
    '''
    if 'momentum' in kwargs:
        optimizer = optimizer_handle(learning_rate=learning_rate, momentum=kwargs['momentum'])
    else:
        optimizer = optimizer_handle(learning_rate=learning_rate)
    # The control_dependencies block ensures that the ops collected in UPDATE_OPS
    # (e.g. the moving-average updates of the tf contrib batch norm layers) are
    # run before each weight update.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)
    return train_op
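

# Usage sketch (illustrative, not from the original post): wiring training_step
# into a minimal TF1 graph. The tensors `x`, `y`, `pred` and the choice of
# tf.train.MomentumOptimizer are assumptions made for demonstration only.
x = tf.placeholder(tf.float32, shape=[None, 10])
y = tf.placeholder(tf.float32, shape=[None, 1])
pred = tf.layers.dense(x, 1)
loss = tf.losses.mean_squared_error(labels=y, predictions=pred)

# momentum is forwarded through **kwargs only for optimisers that accept it
train_op = training_step(loss, tf.train.MomentumOptimizer,
                         learning_rate=0.01, momentum=0.9)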