model.py source code

python
Project: acdc_segmenter  Author: baumgach
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.get_collection, tf.GraphKeys)


def training_step(loss, optimizer_handle, learning_rate, **kwargs):
    '''
    Creates the optimisation operation that is executed in each training iteration of the network.
    :param loss: The loss to be minimised
    :param optimizer_handle: A handle to one of the tf optimiser classes (the class itself, not an instance)
    :param learning_rate: Learning rate
    :param momentum: Optional keyword argument; if given, it is forwarded to the optimiser as its momentum term
    :return: The training operation
    '''

    # Only forward momentum when the caller supplied it, since not every optimiser accepts that argument
    if 'momentum' in kwargs:
        momentum = kwargs.get('momentum')
        optimizer = optimizer_handle(learning_rate=learning_rate, momentum=momentum)
    else:
        optimizer = optimizer_handle(learning_rate=learning_rate)

    # The tf contrib version of batch norm registers its moving-average updates in UPDATE_OPS.
    # The control dependency makes sure those updates run every time the training op is executed.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)

    return train_op
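
The optional-momentum dispatch in training_step can be tried without TensorFlow installed. The sketch below is only an illustration: StubSGD, StubMomentum and build_optimizer are hypothetical names that do not appear in the original project. The stubs imitate the keyword arguments (learning_rate, momentum) that training_step passes to the optimiser class and do no actual optimisation.

```python
class StubSGD:
    """Hypothetical stand-in for an optimiser that takes only a learning rate."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate
        self.momentum = None


class StubMomentum:
    """Hypothetical stand-in for an optimiser that also takes a momentum term."""

    def __init__(self, learning_rate, momentum):
        self.learning_rate = learning_rate
        self.momentum = momentum


def build_optimizer(optimizer_handle, learning_rate, **kwargs):
    # Same branching as in training_step: momentum is forwarded only if supplied,
    # so optimiser classes without a momentum argument keep working.
    if 'momentum' in kwargs:
        return optimizer_handle(learning_rate=learning_rate,
                                momentum=kwargs['momentum'])
    return optimizer_handle(learning_rate=learning_rate)


sgd = build_optimizer(StubSGD, 0.01)
mom = build_optimizer(StubMomentum, 0.01, momentum=0.9)
print(type(sgd).__name__, sgd.momentum)
print(type(mom).__name__, mom.momentum)
```

Passing the optimiser class rather than an instance lets the caller choose the optimiser while the function decides which constructor arguments to supply. The trade-off is that passing momentum to a class that does not accept it raises a TypeError at construction time.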