theano_backend.py source code

python

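Project: keras_superpixel_pooling  Author: parag2489
from theano import tensor as T
# ones_like, zeros_like and _old_batch_normalization are defined
# elsewhere in theano_backend.py.


def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
    """Apply batch normalization on x given mean, var, beta and gamma.

    Computes `(x - mean) / sqrt(var + epsilon) * gamma + beta` from
    precomputed statistics (inference mode). If `gamma` is None it
    defaults to ones; if `beta` is None it defaults to zeros.
    """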
    # TODO remove this if statement once Theano versions lacking
    # T.nnet.bn.batch_normalization_test are no longer supported
    if not hasattr(T.nnet.bn, 'batch_normalization_test'):
        return _old_batch_normalization(x, mean, var, beta, gamma, epsilon)

    if gamma is None:
        gamma = ones_like(var)
    if beta is None:
        beta = zeros_like(mean)

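    if mean.ndim == 1:
        # 1-D statistics index the last (feature) axis, matching
        # TensorFlow's default: reduce over every other axis.
        reduction_axes = list(range(x.ndim - 1))
    else:
        # Full-rank statistics: reduce over the axes along which the
        # statistics are broadcastable (i.e. have length 1).
        reduction_axes = [i for i in range(x.ndim) if mean.broadcastable[i]]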

    return T.nnet.bn.batch_normalization_test(
        x, gamma, beta, mean, var, reduction_axes, epsilon)


# TODO remove this function once Theano versions lacking
# T.nnet.bn.batch_normalization_train are no longer supported
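To make the arithmetic of batch_normalization concrete, here is a minimal NumPy sketch of what it computes at inference time. It is an illustration under stated assumptions, not the Keras or Theano implementation: `batch_norm_inference` and `reduction_axes_for` are hypothetical helper names, NumPy broadcasting stands in for Theano's symbolic graph, and Theano's `broadcastable` flags are approximated as "this axis has length 1".

```python
import numpy as np


def batch_norm_inference(x, mean, var, beta=None, gamma=None, epsilon=1e-3):
    """Hypothetical NumPy reference (not part of Keras or Theano).

    Computes (x - mean) / sqrt(var + epsilon) * gamma + beta, using the
    same defaults as the backend: gamma -> ones, beta -> zeros.
    """
    x = np.asarray(x, dtype=float)
    mean = np.asarray(mean, dtype=float)
    var = np.asarray(var, dtype=float)
    gamma = np.ones_like(var) if gamma is None else np.asarray(gamma, dtype=float)
    beta = np.zeros_like(mean) if beta is None else np.asarray(beta, dtype=float)
    # 1-D statistics broadcast against the trailing axis of x automatically.
    return (x - mean) / np.sqrt(var + epsilon) * gamma + beta


def reduction_axes_for(x_ndim, mean_shape):
    """Hypothetical mirror of the backend's reduction-axis choice.

    1-D statistics: reduce over every axis except the last.
    Full-rank statistics: reduce over axes where the statistic has
    length 1 (an approximation of Theano's broadcastable pattern).
    """
    if len(mean_shape) == 1:
        return list(range(x_ndim - 1))
    return [i for i in range(x_ndim) if mean_shape[i] == 1]


if __name__ == "__main__":
    x = np.array([[1.0, 2.0], [3.0, 4.0]])
    mean = x.mean(axis=0)  # per-feature mean: [2.0, 3.0]
    var = x.var(axis=0)    # per-feature variance: [1.0, 1.0]
    print(batch_norm_inference(x, mean, var, epsilon=0.0))  # → [[-1, -1], [1, 1]]
    print(reduction_axes_for(4, (16,)))          # → [0, 1, 2]
    print(reduction_axes_for(4, (1, 16, 1, 1)))  # → [0, 2, 3]
```

Because 1-D statistics broadcast against the trailing axis of `x`, the default case needs no explicit reshaping; for other layouts the statistics must already carry singleton axes, which is what the `broadcastable` branch in the backend detects.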