asr_model.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:TF-Speech-Recognition 作者: ZhishengWang 项目源码 文件源码
def _batch_norm_vec(self, name, x, decay=0.9, epsilon=0.001):
    """Batch-normalize ``x`` over its first axis.

    All variables are created (or fetched) under the variable scope ``name``:
      * per-feature ``beta`` (init 0) and ``gamma`` (init 1), and
      * non-trainable ``moving_mean`` (init 0) and ``moving_variance`` (init 1).

    When ``self.mode == 'train'``, the batch mean and variance come from
    ``tf.nn.moments`` over axis 0 and are used to normalize ``x``. Ops that
    fold them into the moving statistics are appended to
    ``self._extra_train_ops``; the caller is responsible for running them.

    In any other mode, the stored moving statistics are used instead, and
    histogram summaries of them are emitted.

    Args:
        name: Variable-scope name for this layer's variables.
        x: Input tensor. Moments are taken over axis 0 and the parameters
            have the size of the last axis. This presumably expects a 2-D
            (batch, features) tensor -- TODO confirm for higher ranks.
        decay: Decay passed to ``assign_moving_average`` when updating the
            moving statistics (default 0.9, the previously hard-coded value).
        epsilon: Variance epsilon passed to ``tf.nn.batch_normalization``.
            The default is 0.001 (previously hard-coded). The original
            author's note says epsilon used to be 1e-5 and 1e-3 was adopted
            in the hope of avoiding NaNs in deeper nets.

    Returns:
        The normalized tensor, with its static shape set to that of ``x``.
    """
    with tf.variable_scope(name):
        params_shape = [x.get_shape()[-1]]

        beta = tf.get_variable(
            'beta', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32))
        gamma = tf.get_variable(
            'gamma', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32))

        # The moving statistics must live inside this layer's scope. Created
        # outside it, every call would request the same 'moving_mean' /
        # 'moving_variance' names in the parent scope and collide across layers.
        # NOTE(review): checkpoints written by the old code may store these
        # under the parent scope instead of `name/` -- confirm before restoring.
        moving_mean = tf.get_variable(
            'moving_mean', params_shape, tf.float32,
            initializer=tf.constant_initializer(0.0, tf.float32),
            trainable=False)
        moving_variance = tf.get_variable(
            'moving_variance', params_shape, tf.float32,
            initializer=tf.constant_initializer(1.0, tf.float32),
            trainable=False)

        # NOTE(review): an optional `Qmf_quan(param, 4, 7)` call was
        # previously left commented out for beta/gamma and for the statistics.
        # Qmf_quan is defined elsewhere; confirm its semantics before re-enabling.

        if self.mode == 'train':
            mean, variance = tf.nn.moments(x, [0], name='moments')
            self._extra_train_ops.append(moving_averages.assign_moving_average(
                moving_mean, mean, decay))
            self._extra_train_ops.append(moving_averages.assign_moving_average(
                moving_variance, variance, decay))
        else:
            mean, variance = moving_mean, moving_variance
            tf.summary.histogram(mean.op.name, mean)
            tf.summary.histogram(variance.op.name, variance)

        y = tf.nn.batch_normalization(
            x, mean, variance, beta, gamma, epsilon)
        y.set_shape(x.get_shape())
        return y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号