restnet_tensorflow.py source file

python

Project: mlAlgorithms    Author: gu-yan
import tensorflow as tf  # TensorFlow 1.x graph-mode API
from tensorflow.python.training import moving_averages

FLAGS = tf.app.flags.FLAGS  # FLAGS.train_mode is defined elsewhere in the original file


def batch_norm(data, name):
    """Batch normalization over the N, H, W axes of an NHWC tensor."""
    shape_param = data.get_shape()[-1]  # one parameter per channel
    # Learnable offset (beta) and scale (gamma)
    beta = tf.get_variable(name=name+'_beta', shape=shape_param, dtype=tf.float32,
                           initializer=tf.constant_initializer(0.0, tf.float32))
    gamma = tf.get_variable(name=name+'_gamma', shape=shape_param, dtype=tf.float32,
                            initializer=tf.constant_initializer(1.0, tf.float32))
    # Non-trainable running statistics, used for normalization at inference time
    moving_mean = tf.get_variable(name=name+'_moving_mean', shape=shape_param, dtype=tf.float32,
                                  initializer=tf.constant_initializer(0.0, tf.float32), trainable=False)
    moving_variance = tf.get_variable(name=name+'_moving_variance', shape=shape_param, dtype=tf.float32,
                                      initializer=tf.constant_initializer(1.0, tf.float32), trainable=False)
    if FLAGS.train_mode:
        # Training: normalize with the current batch's per-channel statistics
        mean_param, variance_param = tf.nn.moments(x=data, axes=[0, 1, 2], name=name+'_moments')
        update_mean = moving_averages.assign_moving_average(variable=moving_mean, value=mean_param, decay=0.9)
        update_variance = moving_averages.assign_moving_average(variable=moving_variance, value=variance_param,
                                                                decay=0.9)
        # Run the moving-average updates every time the normalized output is computed
        with tf.control_dependencies([update_mean, update_variance]):
            mean = tf.identity(mean_param)
            variance = tf.identity(variance_param)
    else:
        # Inference: normalize with the accumulated moving statistics
        mean, variance = moving_mean, moving_variance
        # The statistics are per-channel vectors, so log them as histograms, not scalars
        tf.summary.histogram(moving_mean.op.name, moving_mean)
        tf.summary.histogram(moving_variance.op.name, moving_variance)
    b_norm = tf.nn.batch_normalization(x=data, mean=mean, variance=variance,
                                       offset=beta, scale=gamma, variance_epsilon=0.001, name=name)
    return b_norm
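For checking the function's behavior outside a TensorFlow graph, here is a minimal NumPy sketch of standard batch normalization with the same settings. It assumes NHWC input with statistics over axes 0, 1 and 2, decay=0.9, variance_epsilon=0.001, and beta/gamma left at their initial values. The class name BatchNormReference is hypothetical and is not part of mlAlgorithms. In training mode it normalizes with the current batch statistics and updates the running averages. In inference mode it uses the running averages.

The running-average update here is the plain exponential form. In TF 1.x, assign_moving_average defaults to zero_debias=True, which adds a bias correction, so early values of the TensorFlow moving statistics can differ from this sketch.

```python
import numpy as np


class BatchNormReference:
    """Hypothetical NumPy reference for per-channel batch norm on NHWC arrays."""

    def __init__(self, channels, decay=0.9, eps=1e-3):
        self.beta = np.zeros(channels, dtype=np.float32)    # offset, initialized to 0.0
        self.gamma = np.ones(channels, dtype=np.float32)    # scale, initialized to 1.0
        self.moving_mean = np.zeros(channels, dtype=np.float32)
        self.moving_variance = np.ones(channels, dtype=np.float32)
        self.decay = decay
        self.eps = eps

    def __call__(self, data, train_mode):
        if train_mode:
            # Per-channel batch statistics, like tf.nn.moments(x, axes=[0, 1, 2])
            # (population variance, i.e. divided by N)
            mean = data.mean(axis=(0, 1, 2))
            variance = data.var(axis=(0, 1, 2))
            # Plain exponential moving average: v <- decay * v + (1 - decay) * batch_stat
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_variance = self.decay * self.moving_variance + (1 - self.decay) * variance
        else:
            mean, variance = self.moving_mean, self.moving_variance
        # Same formula tf.nn.batch_normalization applies
        return self.gamma * (data - mean) / np.sqrt(variance + self.eps) + self.beta


# Usage: one training step, then an inference pass on the same batch
rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(8, 4, 4, 16)).astype(np.float32)
bn = BatchNormReference(channels=16)
y_train = bn(x, train_mode=True)   # roughly zero mean, unit variance per channel
y_eval = bn(x, train_mode=False)   # normalized with the running statistics instead
```

After a single step the running mean is only 10% of the way to the batch mean. This is why inference outputs can look poorly normalized until the moving averages have seen many batches.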