KerasBatchNormalization.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:audit-log-detection 作者: twosixlabs 项目源码 文件源码
def call(self, x, mask=None):
        """Apply batch normalization to ``x`` according to ``self.mode``.

        Modes (as implemented below):
          * 0 -- feature-wise normalization over every axis except ``self.axis``.
            Registers moving-average updates of ``running_mean``/``running_std``
            in ``self.updates`` and uses the batch statistics in the training
            phase and the running statistics otherwise (``K.in_train_phase``).
          * 1 -- sample-wise normalization along the last axis; gamma/beta are
            applied without reshaping.
          * 2 -- feature-wise normalization that always uses the statistics of
            the current batch; no running statistics are updated.

        Args:
            x: input tensor.
            mask: accepted for layer-API compatibility; not used here.

        Returns:
            The normalized, scaled (gamma) and shifted (beta) tensor.

        Raises:
            ValueError: if ``self.mode`` is not 0, 1 or 2 (previously this
                surfaced as an UnboundLocalError on ``out``).
        """
        if self.mode not in (0, 1, 2):
            raise ValueError('Unsupported BatchNormalization mode: %r '
                             '(expected 0, 1 or 2)' % (self.mode,))

        if self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=-1, keepdims=True)
            std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
            x_normed = (x - m) / (std + self.epsilon)
            return self.gamma * x_normed + self.beta

        # modes 0 and 2: feature-wise normalization
        assert self.built, 'Layer must be built before being called'
        input_shape = self.input_spec[0].shape

        # Statistics are reduced over every axis except the feature axis;
        # broadcast_shape lets per-feature vectors line up with x.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # case: train mode (uses stats of the current batch)
        mean = K.mean(x, axis=reduction_axes)
        broadcast_mean = K.reshape(mean, broadcast_shape)
        # NOTE: epsilon is added inside the variance here AND to the
        # denominator below. Kept as-is deliberately: changing it would alter
        # the outputs/running statistics of models trained with this layer.
        std = K.sqrt(K.mean(K.square(x - broadcast_mean) + self.epsilon,
                            axis=reduction_axes))
        broadcast_std = K.reshape(std, broadcast_shape)
        x_normed = (x - broadcast_mean) / (broadcast_std + self.epsilon)

        if self.mode == 0:
            # Exponential moving averages of the batch statistics; only mode 0
            # tracks them, so they are not built for mode 2.
            mean_update = self.momentum * self.running_mean + (1 - self.momentum) * mean
            std_update = self.momentum * self.running_std + (1 - self.momentum) * std
            self.called_with = x
            self.updates = [(self.running_mean, mean_update),
                            (self.running_std, std_update)]

            # case: test mode (uses running averages)
            broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
            broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
            x_normed_running = ((x - broadcast_running_mean) /
                                (broadcast_running_std + self.epsilon))

            # pick the normalized form of x corresponding to the training phase
            x_normed = K.in_train_phase(x_normed, x_normed_running)

        return (K.reshape(self.gamma, broadcast_shape) * x_normed +
                K.reshape(self.beta, broadcast_shape))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号