def normalize_batch_in_training(x, gamma, beta,
                                reduction_axes, epsilon=0.0001):
    '''Normalizes a batch using its own mean and variance.

    The mean and variance of `x` are computed over `reduction_axes` and
    then used to apply batch normalization to `x`.

    # Arguments
        x: input tensor.
        gamma: scale tensor; it is reshaped to the shape of `x` with every
            reduced axis set to 1.
        beta: offset tensor; reshaped the same way as `gamma`.
        reduction_axes: list or tuple of integer axes to reduce over.
        epsilon: small constant forwarded to the normalization op.

    # Returns
        A tuple `(normed, mean, var)`: the normalized tensor, followed by
        the batch mean and the batch variance.
    '''
    dev = theano.config.device
    # `list(...)` lets a tuple such as (0, 2, 3) select the cuDNN path too;
    # a bare `== [0, 2, 3]` is always False for tuples.
    use_cudnn = (ndim(x) < 5 and
                 list(reduction_axes) == [0, 2, 3] and
                 dev.startswith(('cuda', 'gpu')))
    if use_cudnn:
        # Put the 1-D parameters on axis 1 and make the other axes
        # broadcastable.
        broadcast_beta = beta.dimshuffle('x', 0, 'x', 'x')
        broadcast_gamma = gamma.dimshuffle('x', 0, 'x', 'x')
        try:
            normed, mean, stdinv = theano.sandbox.cuda.dnn.dnn_batch_normalization_train(
                x, broadcast_gamma, broadcast_beta, 'spatial', epsilon)
            # NOTE(review): the variance is recovered from the inverse std;
            # if the op folds epsilon into `stdinv`, this value is
            # var + epsilon rather than var -- confirm against the op docs.
            var = T.inv(stdinv ** 2)
            return normed, T.flatten(mean), T.flatten(var)
        except AttributeError:
            # Best effort: an AttributeError (e.g. the cuDNN op is missing
            # from this Theano build) falls through to the generic
            # implementation below.
            pass

    var = x.var(reduction_axes)
    mean = x.mean(reduction_axes)

    # Shape of `x` with each reduced axis collapsed to 1, so the statistics
    # and parameters broadcast against `x`.
    target_shape = T.stack(*[1 if axis in reduction_axes else x.shape[axis]
                             for axis in range(ndim(x))])

    broadcast_mean = T.reshape(mean, target_shape)
    broadcast_var = T.reshape(var, target_shape)
    broadcast_beta = T.reshape(beta, target_shape)
    broadcast_gamma = T.reshape(gamma, target_shape)
    normed = batch_normalization(x, broadcast_mean, broadcast_var,
                                 broadcast_beta, broadcast_gamma,
                                 epsilon)
    return normed, mean, var
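# NOTE(review): the two lines below were stray non-code text (apparently
# page-navigation labels) and are kept here only as comments -- confirm
# they can be removed.
# Comment list
# Table of contents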