common.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:tensorflow_multigpu_imagenet 作者: arashno 项目源码 文件源码
def batchNormalization(x, is_training=True, decay=0.9, epsilon=0.001):
    """Batch-normalize ``x`` over every axis except the last one.

    Four variables are obtained through ``_get_variable``, each shaped like
    the last dimension of ``x``:
      * ``beta``  - offset, zeros-initialized.
      * ``gamma`` - scale, ones-initialized.
      * ``moving_mean`` - zeros-initialized, created with ``trainable=False``.
      * ``moving_variance`` - ones-initialized, created with ``trainable=False``.

    Args:
      x: Input tensor. Moments are reduced over all axes but the last.
      is_training: Evaluated as a Python truth value (so it must not be a
        ``tf.Tensor``).
        * Truthy: normalize with the current batch's mean/variance and add
          the moving-average update ops to ``tf.GraphKeys.UPDATE_OPS``.
        * Falsy: normalize with ``moving_mean`` / ``moving_variance``.
      decay: Decay passed to ``moving_averages.assign_moving_average``.
      epsilon: ``variance_epsilon`` passed to ``tf.nn.batch_normalization``.

    Returns:
      The tensor produced by ``tf.nn.batch_normalization``.
    """
    input_shape = x.get_shape()
    channel_shape = input_shape[-1:]
    # Reduce over every axis except the trailing (channel) axis.
    reduce_axes = list(range(len(input_shape) - 1))

    beta = _get_variable('beta', channel_shape,
                         initializer=tf.zeros_initializer)
    gamma = _get_variable('gamma', channel_shape,
                          initializer=tf.ones_initializer)
    moving_mean = _get_variable('moving_mean', channel_shape,
                                initializer=tf.zeros_initializer,
                                trainable=False)
    moving_variance = _get_variable('moving_variance', channel_shape,
                                    initializer=tf.ones_initializer,
                                    trainable=False)

    if not is_training:
        # Inference path: use only the accumulated statistics.
        return tf.nn.batch_normalization(x, moving_mean, moving_variance,
                                         beta, gamma, epsilon)

    # Training path: normalize with this batch's statistics.
    batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)
    stat_updates = [
        moving_averages.assign_moving_average(moving_mean, batch_mean, decay),
        moving_averages.assign_moving_average(moving_variance, batch_variance,
                                              decay),
    ]
    # The updates are only registered here. They take effect only if
    # something runs UPDATE_OPS (presumably the training step; confirm with
    # the caller).
    for update_op in stat_updates:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
    return tf.nn.batch_normalization(x, batch_mean, batch_variance,
                                     beta, gamma, epsilon)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号