frame_level_models.py source code

python
Project: youtube-8m  Author: wangheda
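```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API


# Method of a batch-normalization helper class; self.name, self.epsilon and
# self.ema (a tf.train.ExponentialMovingAverage) are set in __init__, not shown here.
def __call__(self, x, train=True):
    shape = x.get_shape().as_list()

    with tf.variable_scope(self.name):
        # Learned offset (beta) and scale (gamma), one per non-batch position.
        self.beta = tf.get_variable("beta", shape[1:],
                                    initializer=tf.constant_initializer(0.))
        self.gamma = tf.get_variable("gamma", shape[1:],
                                     initializer=tf.random_normal_initializer(1., 0.02))
        # Non-trainable holders for the latest batch statistics. use_resource=True:
        # each read creates a new op, so reads made under tf.control_dependencies
        # below see the freshly assigned values.
        self.mean = tf.get_variable("mean", shape[1:], use_resource=True,
                                    initializer=tf.constant_initializer(0.), trainable=False)
        self.variance = tf.get_variable("variance", shape[1:], use_resource=True,
                                        initializer=tf.constant_initializer(1.), trainable=False)
        if train:
            # Statistics over the batch axis only, one value per remaining position.
            batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')

            # The original discarded these assign ops, so in graph mode they never
            # ran and the moving averages stayed at their initial values (0 and 1).
            assign_mean = self.mean.assign(batch_mean)
            assign_var = self.variance.assign(batch_var)
            with tf.control_dependencies([assign_mean, assign_var]):
                ema_apply_op = self.ema.apply([self.mean, self.variance])
            # Normalize with the batch statistics once the EMA update has run.
            with tf.control_dependencies([ema_apply_op]):
                mean, var = tf.identity(batch_mean), tf.identity(batch_var)
        else:
            # Inference: use the moving averages. ema.average() returns None unless
            # ema.apply() was already called for these variables (a train=True call).
            mean, var = self.ema.average(self.mean), self.ema.average(self.variance)

        normed = tf.nn.batch_normalization(x, mean, var, self.beta, self.gamma, self.epsilon)

    return normed
```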
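For reference, the arithmetic this method performs can be sketched in NumPy without TensorFlow. This is a minimal sketch, not the project's code: it assumes `tf.nn.moments` computes the population variance (`ddof=0`) over axis 0, and that the moving average follows the standard `tf.train.ExponentialMovingAverage` rule `shadow = decay * shadow + (1 - decay) * value` (ignoring its optional `num_updates` adjustment). The helper names `batch_norm_numpy` and `ema_update`, the decay of 0.9 and the input shape are hypothetical.

```python
import numpy as np


def batch_norm_numpy(x, beta, gamma, epsilon):
    """Normalize x with its own statistics over axis 0, as in the train=True branch."""
    mean = x.mean(axis=0)  # per-position mean over the batch
    var = x.var(axis=0)    # population variance (ddof=0)
    normed = (x - mean) / np.sqrt(var + epsilon) * gamma + beta
    return normed, mean, var


def ema_update(shadow, value, decay):
    """One moving-average step: shadow <- decay * shadow + (1 - decay) * value."""
    return decay * shadow + (1.0 - decay) * value


# Hypothetical frame-level input: [batch, frames, features].
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 5, 4))
beta = np.zeros(x.shape[1:])   # matches the beta initializer (0.)
gamma = np.ones(x.shape[1:])   # gamma's random-normal init is centered on 1.
normed, batch_mean, batch_var = batch_norm_numpy(x, beta, gamma, epsilon=1e-5)

# Shadow values start from the variables' initial values (0 and 1).
ema_mean = ema_update(np.zeros(x.shape[1:]), batch_mean, decay=0.9)
ema_var = ema_update(np.ones(x.shape[1:]), batch_var, decay=0.9)

# Inference normalizes with the moving averages instead of the batch statistics.
inferred = (x - ema_mean) / np.sqrt(ema_var + 1e-5) * gamma + beta
```

Using the batch statistics during training and the moving averages at inference is what lets a single example be normalized at test time, when no meaningful batch statistics exist.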