layers.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:third_person_im 作者: bstadie 项目源码 文件源码
def __init__(self, incoming, center=True, scale=False, epsilon=0.001, decay=0.9,
             beta=tf.zeros_initializer, gamma=tf.ones_initializer,
             moving_mean=tf.zeros_initializer, moving_variance=tf.ones_initializer,
             **kwargs):
        """Set up the layer's configuration and register its parameters.

        The parent constructor is invoked with ``incoming`` and any extra
        ``kwargs``. Every parameter created here has the shape of the last
        dimension of ``incoming.output_shape``.

        Args:
            incoming: Upstream layer; only its ``output_shape`` is read here.
            center: When true, a trainable, non-regularizable ``beta``
                parameter is registered; otherwise ``self.beta`` is ``None``.
            scale: When true, a trainable, regularizable ``gamma`` parameter
                is registered; otherwise ``self.gamma`` is ``None``.
            epsilon: Stored unchanged on ``self.epsilon``.
            decay: Stored unchanged on ``self.decay``.
            beta: Initializer handed to ``add_param`` for ``beta``.
            gamma: Initializer handed to ``add_param`` for ``gamma``.
            moving_mean: Initializer handed to ``add_param`` for the
                non-trainable ``moving_mean`` parameter.
            moving_variance: Initializer handed to ``add_param`` for the
                non-trainable ``moving_variance`` parameter.
            **kwargs: Forwarded untouched to the parent constructor.
        """
        super(BatchNormLayer, self).__init__(incoming, **kwargs)

        self.center, self.scale = center, scale
        self.epsilon, self.decay = epsilon, decay

        in_shape = incoming.output_shape
        # Every dimension index except the last one.
        # NOTE(review): presumably the axes the statistics are reduced over --
        # confirm against the method that consumes self.axis.
        self.axis = list(range(len(in_shape) - 1))
        # One-element slice holding only the size of the last dimension.
        channel_shape = in_shape[-1:]

        # Registration order (beta, gamma, moving_mean, moving_variance) is
        # kept as-is in case add_param records parameters in call order.
        self.beta = (
            self.add_param(beta, shape=channel_shape, name='beta',
                           trainable=True, regularizable=False)
            if center else None
        )
        self.gamma = (
            self.add_param(gamma, shape=channel_shape, name='gamma',
                           trainable=True, regularizable=True)
            if scale else None
        )

        # The running statistics share the same flags: neither trained nor
        # regularized.
        stat_flags = dict(trainable=False, regularizable=False)
        self.moving_mean = self.add_param(
            moving_mean, shape=channel_shape, name='moving_mean', **stat_flags)
        self.moving_variance = self.add_param(
            moving_variance, shape=channel_shape, name='moving_variance', **stat_flags)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号