nn.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:deepsleepnet 作者: akaraspt 项目源码 文件源码
def batch_norm_new(name, input_var, is_train, decay=0.999, epsilon=1e-5):
    """Batch normalization modified from BatchNormLayer in Tensorlayer.

    Source: <https://github.com/zsdonghao/tensorlayer/blob/master/tensorlayer/layers.py#L2190>

    Normalizes ``input_var`` over every axis except the last one, then applies
    a learned per-channel offset (``beta``) and scale (``gamma``).

    Args:
        name: Variable scope under which ``beta``, ``gamma``, ``moving_mean``
            and ``moving_variance`` are created (or reused, if the enclosing
            scope is in reuse mode).
        input_var: Input tensor. Its last dimension sizes the per-channel
            variables, so it must be statically known.
        is_train: Python bool, evaluated at graph-construction time with a
            plain ``if``. Passing a boolean tensor does not work in graph mode.
            When true, the current batch's mean/variance are used, and both
            moving averages are updated whenever the output is evaluated. When
            false, the stored moving averages are used and no statistics or
            update ops are added to the graph.
        decay: Decay rate passed to ``assign_moving_average`` for the running
            mean and variance.
        epsilon: Small constant added to the variance by
            ``tf.nn.batch_normalization`` to avoid division by zero.

    Returns:
        The output of ``tf.nn.batch_normalization`` (same shape as
        ``input_var``).
    """
    inputs_shape = input_var.get_shape()
    # Reduce over every axis except the last (channel) axis.
    axis = list(range(len(inputs_shape) - 1))
    params_shape = inputs_shape[-1:]

    with tf.variable_scope(name):
        # Trainable offset (beta) and scale (gamma), one value per channel.
        beta = tf.get_variable('beta',
                               shape=params_shape,
                               initializer=tf.zeros_initializer)
        gamma = tf.get_variable('gamma',
                                shape=params_shape,
                                initializer=tf.random_normal_initializer(mean=1.0, stddev=0.002))

        # Non-trainable running statistics. Only the training path below
        # writes to them; the inference path only reads them.
        moving_mean = tf.get_variable('moving_mean',
                                      params_shape,
                                      initializer=tf.zeros_initializer,
                                      trainable=False)
        moving_variance = tf.get_variable('moving_variance',
                                          params_shape,
                                          initializer=tf.constant_initializer(1.),
                                          trainable=False)

        if not is_train:
            # Inference: normalize with the stored running statistics. The
            # batch-moment and moving-average update ops are not built here,
            # because nothing in an inference graph would consume them.
            return tf.nn.batch_normalization(input_var, moving_mean, moving_variance,
                                             beta, gamma, epsilon)

        # Training: normalize with the statistics of the current batch.
        batch_mean, batch_variance = tf.nn.moments(input_var, axis, name='moments')

        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, batch_mean, decay, zero_debias=False)
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, batch_variance, decay, zero_debias=False)

        # Route the batch statistics through identity ops created under these
        # control dependencies. Any evaluation of the normalized output then
        # also runs both moving-average updates.
        with tf.control_dependencies([update_moving_mean, update_moving_variance]):
            mean = tf.identity(batch_mean)
            variance = tf.identity(batch_variance)

        return tf.nn.batch_normalization(input_var, mean, variance, beta, gamma, epsilon)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号