resnet.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:KittiSeg 作者: MarvinTeichmann 项目源码 文件源码
def _bn(x, is_training, hypes):
    """Batch-normalize ``x`` over every axis except the last one.

    Creates, via ``_get_variable``, a ``beta``/``gamma`` pair and a
    non-trainable ``moving_mean``/``moving_variance`` pair, each shaped like
    the last dimension of ``x``.

    Args:
        x: Input tensor. The reduction axes are derived from
            ``len(x.get_shape())``.
        is_training: Predicate handed to ``control_flow_ops.cond``; it is
            only consulted when ``hypes['use_moving_average_bn']`` is truthy.
            Presumably a boolean tensor/placeholder -- confirm with callers.
        hypes: Hyper-parameter mapping; only the key
            ``'use_moving_average_bn'`` is read here.

    Returns:
        The output of ``tf.nn.batch_normalization`` applied to ``x``.
    """
    input_shape = x.get_shape()
    channel_shape = input_shape[-1:]
    reduce_axes = list(range(len(input_shape) - 1))

    # Offset and scale of the normalization.
    beta = _get_variable(
        'beta', channel_shape, initializer=tf.zeros_initializer())
    gamma = _get_variable(
        'gamma', channel_shape, initializer=tf.ones_initializer())

    # Running statistics; excluded from training via trainable=False.
    moving_mean = _get_variable(
        'moving_mean', channel_shape,
        initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = _get_variable(
        'moving_variance', channel_shape,
        initializer=tf.ones_initializer(), trainable=False)

    # Statistics of the current batch; these are always computed.
    batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)

    # The update ops are always built, but they are only registered in
    # UPDATE_OPS_COLLECTION when the moving-average mode is enabled below.
    mean_update = moving_averages.assign_moving_average(
        moving_mean, batch_mean, BN_DECAY)
    variance_update = moving_averages.assign_moving_average(
        moving_variance, batch_variance, BN_DECAY)

    # Default: normalize with the statistics of the current batch.
    norm_mean, norm_variance = batch_mean, batch_variance
    if hypes['use_moving_average_bn']:
        for update_op in (mean_update, variance_update):
            tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

        # Batch statistics when is_training holds, moving averages otherwise.
        norm_mean, norm_variance = control_flow_ops.cond(
            is_training,
            lambda: (batch_mean, batch_variance),
            lambda: (moving_mean, moving_variance))

    return tf.nn.batch_normalization(
        x, norm_mean, norm_variance, beta, gamma, BN_EPSILON)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号