ops_org.py source code

python

Project: AM-GAN  Author: ZhimingZhou
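import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope, tf.get_variable)


def batch_norm(inputs, cts, ldc, epsilon=0.001, bOffset=True, bScale=True, reuse=None, decay=0.999, is_training=True):
    # Note: decay and is_training are accepted but unused; the layer always
    # normalizes with the statistics of the current batch.

    # get_name is a helper defined elsewhere in the project (not shown here).
    name = get_name('bn', cts)
    with tf.variable_scope(name, reuse=reuse):

        inputs_shape = inputs.get_shape()
        params_shape = inputs_shape[-1:]  # one offset/scale value per channel (last axis)
        axis = list(range(len(inputs_shape) - 1))  # reduce over every axis except the channel axis

        offset, scale = None, None
        if bOffset:
            offset = tf.get_variable('offset', shape=params_shape, trainable=True, initializer=tf.zeros_initializer())
        if bScale:
            scale = tf.get_variable('scale', shape=params_shape, trainable=True, initializer=tf.ones_initializer())

        batch_mean, batch_variance = tf.nn.moments(inputs, axis)
        outputs = tf.nn.batch_normalization(inputs, batch_mean, batch_variance, offset, scale, epsilon)

        # Note: to speed up training, no moving averages of the statistics are kept
        # for use at test time; we usually do not use them.

    # Record a short description of this layer in ldc (a list-like log supplied by the caller).
    ldc.append(name + ' offset:' + str(bOffset) + ' scale:' + str(bScale))
    return outputs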
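The function delegates the arithmetic to `tf.nn.moments` and `tf.nn.batch_normalization`. As a minimal NumPy sketch (not the project's code), the same per-channel computation can be written out directly, assuming a channels-last layout such as NHWC: compute the mean and population variance over every axis except the last, then apply `scale * (x - mean) / sqrt(var + epsilon) + offset`. The names `batch_norm_np` and `update_moving_stats` are hypothetical. The second helper illustrates the exponential moving average that the original function deliberately skips; its `decay` argument mirrors the unused `decay=0.999` parameter above.

```python
import numpy as np


def batch_norm_np(x, offset=None, scale=None, epsilon=0.001):
    """Hypothetical NumPy equivalent of tf.nn.moments + tf.nn.batch_normalization.

    Normalizes x over every axis except the last (channel) axis and returns
    the output together with the per-channel batch mean and variance.
    """
    axes = tuple(range(x.ndim - 1))  # all axes except the channel axis
    mean = x.mean(axis=axes)
    var = x.var(axis=axes)  # population variance (ddof=0), as tf.nn.moments computes
    y = (x - mean) / np.sqrt(var + epsilon)
    if scale is not None:
        y = y * scale
    if offset is not None:
        y = y + offset
    return y, mean, var


def update_moving_stats(moving_mean, moving_var, mean, var, decay=0.999):
    """Hypothetical helper: exponential moving average of batch statistics.

    A layer that keeps these running values would use them in place of the
    per-batch mean and variance at test time.
    """
    new_mean = decay * moving_mean + (1.0 - decay) * mean
    new_var = decay * moving_var + (1.0 - decay) * var
    return new_mean, new_var


# Usage: a random NHWC batch with 16 channels.
rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(8, 4, 4, 16))
y, mean, var = batch_norm_np(x)
moving_mean, moving_var = update_moving_stats(np.zeros(16), np.ones(16), mean, var, decay=0.9)
```

After normalization each channel of `y` has a mean of about zero and a variance of `var / (var + epsilon)`, just under one; with the default `offset=None` and `scale=None` no affine transform is applied, matching the `bOffset=False, bScale=False` case.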