def __call__(self, input):
    """Batch-normalise ``input`` symbolically: ``gamma * (input - mean) / std + beta``.

    The mean and standard deviation are computed over ``self.axes`` with
    ``keepdims=True``, and ``self.epsilon`` is added to the standard
    deviation so the division never hits zero.  When ``input.shape[0]`` is
    not greater than 1, the statistics are replaced by mean 0 and std 1, so
    a single data point is only scaled by ``gamma`` and shifted by ``beta``.

    NOTE(review): axis 0 is presumably the batch axis (it is the one tested
    for the single-point case) -- confirm against callers.  ``self.gamma``
    and ``self.beta`` are assumed to broadcast against the keepdims-shaped
    statistics -- the original expression already required this.

    Parameters
    ----------
    input : symbolic tensor
        The layer input.

    Returns
    -------
    symbolic tensor
        The normalised input, scaled by ``gamma`` and shifted by ``beta``.
    """
    mean = input.mean(self.axes, keepdims=True)
    std = input.std(self.axes, keepdims=True) + self.epsilon

    # Don't batch-normalise a single data point: statistics over one sample
    # are degenerate (input - mean == 0), so fall back to mean 0 / std 1.
    has_batch = T.gt(input.shape[0], 1)
    mean = ifelse(has_batch, mean, T.zeros_like(mean))
    std = ifelse(has_batch, std, T.ones_like(std))

    # Re-mark self.axes as broadcastable before broadcasting against input,
    # as the original did (presumably because the ifelse output may not carry
    # the keepdims broadcast flags).  gamma and beta are each combined with a
    # std-shaped tensor first so both have input's ndim before addbroadcast.
    scale = T.addbroadcast(self.gamma / std, *self.axes)
    shift = T.addbroadcast(self.beta + T.zeros_like(std), *self.axes)

    # Bug fix: beta was previously added *inside* the scale factor, i.e.
    # (input - mean) * (gamma / std + beta); it is an additive shift applied
    # after scaling.
    return (input - mean) * scale + shift
BatchNormLayer.py 文件源码
python
阅读 17
收藏 0
点赞 0
评论 0
评论列表
文章目录