def LayerNormalization(x, gamma, mask, estimated_mean=0.0, estimated_var=1.0):
    """Apply layer normalization over the last axis of a 2-D or 3-D tensor.

    Statistics (mean/variance) are computed per position over the last
    axis, the input is standardized, and the learned gain ``gamma`` is
    applied.

    Parameters
    ----------
    x : Theano tensor, 2-D or 3-D
        Input activations; normalization is over the last axis.
    gamma : Theano tensor or scalar
        Learned gain multiplied into the normalized output.
    mask : unused
        Accepted only for interface compatibility with callers; this
        implementation never reads it.
    estimated_mean, estimated_var : float, optional, unused
        Accepted only for interface compatibility (presumably for
        inference-time running statistics — TODO confirm against callers);
        this implementation never reads them.

    Returns
    -------
    tuple
        ``(normalized, mean_sample, var_sample)`` where ``normalized`` has
        the same shape as ``x``, and ``mean_sample`` / ``var_sample`` are
        the first element of the per-position mean / variance tensors
        (handy for monitoring during training).

    Raises
    ------
    ValueError
        If ``x`` is neither 2-D nor 3-D.  (Previously an ``assert``, which
        is silently stripped under ``python -O``.)
    """
    eps = 1e-7  # numerical-stability constant added to the variance before sqrt
    if x.ndim == 3:
        # dimshuffle appends a broadcastable axis so the stats broadcast
        # back against x's last axis.
        x_mean = T.mean(x, axis=2).dimshuffle(0, 1, 'x')
        x_var = T.var(x, axis=2).dimshuffle(0, 1, 'x')
        return gamma * ((x - x_mean) / T.sqrt(x_var + eps)), x_mean[0, 0], x_var[0, 0]
    elif x.ndim == 2:
        x_mean = T.mean(x, axis=1).dimshuffle(0, 'x')
        x_var = T.var(x, axis=1).dimshuffle(0, 'x')
        return gamma * ((x - x_mean) / T.sqrt(x_var + eps)), x_mean[0], x_var[0]
    raise ValueError("LayerNormalization expects a 2-D or 3-D tensor, got ndim=%d" % x.ndim)
# Does theano.batched_dot. If last_axis is on it will loop over the last axis, otherwise it will loop over the first axis.