def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
    """Normalize ``x`` with precomputed statistics, then apply scale and offset.

    Delegates to Theano's ``T.nnet.bn.batch_normalization_test`` when the
    installed Theano provides it; otherwise falls back to
    ``_old_batch_normalization``.

    Arguments:
        x: Input tensor.
        mean: Mean tensor. When ``mean`` is 1-D, every axis of ``x`` except
            the last is treated as a reduction axis. Otherwise the reduction
            axes are the first ``x.ndim`` dimensions of ``mean`` that are
            broadcastable.
        var: Variance tensor.
        beta: Offset tensor, or ``None``. On the new-op path ``None`` is
            replaced by ``zeros_like(mean)``.
        gamma: Scale tensor, or ``None``. On the new-op path ``None`` is
            replaced by ``ones_like(var)``.
        epsilon: Fuzz factor, forwarded unchanged to the underlying op.

    Returns:
        The tensor returned by ``T.nnet.bn.batch_normalization_test``, or by
        ``_old_batch_normalization`` on Theano builds lacking that op. On the
        fallback path ``beta`` and ``gamma`` are passed through untouched, so
        a ``None`` reaches ``_old_batch_normalization`` as-is.
    """
    # TODO drop this fallback once Theano releases that lack
    # T.nnet.bn.batch_normalization_test are no longer supported.
    if not hasattr(T.nnet.bn, 'batch_normalization_test'):
        return _old_batch_normalization(x, mean, var, beta, gamma, epsilon)

    # Fill in identity scale / zero offset only for the new-op path.
    gamma = ones_like(var) if gamma is None else gamma
    beta = zeros_like(mean) if beta is None else beta

    if mean.ndim != 1:
        # A broadcastable dimension of ``mean`` marks an axis that was
        # reduced over when the statistics were computed.
        reduction_axes = [axis for axis in range(x.ndim)
                          if mean.broadcastable[axis]]
    else:
        # Follow TensorFlow's default: reduce over every axis except the
        # rightmost one.
        reduction_axes = list(range(x.ndim - 1))

    return T.nnet.bn.batch_normalization_test(
        x, gamma, beta, mean, var, reduction_axes, epsilon)
# TODO remove this function once Theano releases that lack
# T.nnet.bn.batch_normalization_train are no longer supported
theano_backend.py 文件源码
python
阅读 34
收藏 0
点赞 0
评论 0
评论列表
文章目录