nn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目：chinese_image_captioning 作者：yuanx520 项目源码 文件源码
def _batch_norm(x, name, is_train, decay=0.99, epsilon=1e-3):
    """Apply a batch normalization layer over all but the last axis of ``x``.

    Creates four variables under ``tf.variable_scope(name)``, each with one
    entry per element of the last axis of ``x``:
    non-trainable ``mean`` / ``variance`` moving statistics (initialized to
    0 and 1) and trainable ``offset`` (beta, init 0) / ``scale`` (gamma,
    init 1).

    Args:
        x: Input tensor whose last dimension is statically known; the batch
            statistics are computed over every other axis.
        name: Variable scope name for this layer's variables.
        is_train: Predicate passed to ``tf.cond`` (presumably a scalar
            boolean tensor — confirm against callers). When true, the
            current batch statistics normalize ``x`` and the moving
            statistics are updated. When false, the stored moving
            statistics are used and nothing is updated.
        decay: Decay passed to ``assign_moving_average`` for both moving
            statistics. Defaults to 0.99.
        epsilon: Variance epsilon passed to ``tf.nn.batch_normalization``.
            Defaults to 1e-3.

    Returns:
        The normalized tensor produced by ``tf.nn.batch_normalization``.
    """
    with tf.variable_scope(name):
        inputs_shape = x.get_shape()
        # Reduce over every axis except the last (per-feature) axis.
        axis = list(range(len(inputs_shape) - 1))
        param_shape = int(inputs_shape[-1])

        moving_mean = tf.get_variable('mean', [param_shape], initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable('variance', [param_shape], initializer=tf.constant_initializer(1.0), trainable=False)

        beta = tf.get_variable('offset', [param_shape], initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable('scale', [param_shape], initializer=tf.constant_initializer(1.0))

        def mean_var_with_update():
            """Training branch: batch statistics, with moving-average updates."""
            mean, var = tf.nn.moments(x, axis)
            update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, decay)
            update_moving_var = moving_averages.assign_moving_average(moving_var, var, decay)
            # In graph mode, ops created inside a tf.cond branch only execute
            # if the branch's outputs depend on them. Making the returned
            # identities depend on the updates guarantees the moving
            # statistics are refreshed every training step.
            with tf.control_dependencies([update_moving_mean, update_moving_var]):
                return tf.identity(mean), tf.identity(var)

        def mean_var():
            """Inference branch: stored moving statistics, no updates."""
            return tf.identity(moving_mean), tf.identity(moving_var)

        mean, var = tf.cond(is_train, mean_var_with_update, mean_var)
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)

    return normed
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号