def normalize_batch_in_training(x, gamma, beta,
                                reduction_axes, epsilon=1e-3):
    """Computes mean and std for batch then apply batch_normalization on batch.

    Arguments:
        x: input tensor to normalize.
        gamma: scale tensor, or None to use a tensor of ones.
        beta: offset tensor, or None to use a tensor of zeros.
        reduction_axes: axes passed through to
            `T.nnet.bn.batch_normalization_train` as the axes to
            normalize over.
        epsilon: fuzz factor forwarded to the Theano op.

    Returns:
        A tuple `(normed, mean, T.inv(stdinv ** 2))`, where `normed`, `mean`
        and `stdinv` are the three outputs of
        `T.nnet.bn.batch_normalization_train`. The third element is the
        reciprocal of the squared `stdinv` (presumably the batch variance
        plus epsilon; confirm against the Theano documentation).
    """
    # TODO remove this if statement when Theano without
    # T.nnet.bn.batch_normalization_train is deprecated
    if not hasattr(T.nnet.bn, 'batch_normalization_train'):
        return _old_normalize_batch_in_training(
            x, gamma, beta, reduction_axes, epsilon)

    if gamma is None:
        # Borrow the shape from beta when it is given, otherwise from x.
        if beta is None:
            gamma = ones_like(x)
        else:
            gamma = ones_like(beta)
    if beta is None:
        # gamma is guaranteed to be set by the block above, so its shape
        # is always available here.
        beta = zeros_like(gamma)

    normed, mean, stdinv = T.nnet.bn.batch_normalization_train(
        x, gamma, beta, reduction_axes, epsilon)

    return normed, mean, T.inv(stdinv ** 2)
theano_backend.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录