ops.py source code

python

Project: CycleGAN-Tensorflow  Author: gitlimlab
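import tensorflow as tf  # TensorFlow 1.x: tf.variable_scope and tf.contrib were removed in 2.x


def _norm(input, is_train, reuse=True, norm=None):
    """Apply instance norm, batch norm, or no normalization to an NHWC tensor."""
    assert norm in ['instance', 'batch', None]
    if norm == 'instance':
        with tf.variable_scope('instance_norm', reuse=reuse):
            eps = 1e-5
            # Per-sample, per-channel statistics over the spatial axes (H, W).
            # tf.nn.moments returns the variance, not the standard deviation.
            mean, variance = tf.nn.moments(input, [1, 2], keep_dims=True)
            # Add eps inside the sqrt so the gradient stays finite when variance == 0.
            normalized = (input - mean) / tf.sqrt(variance + eps)
            out = normalized
            # Optional learnable affine transform (per-channel scale and shift), not mandatory
            #c = input.get_shape()[-1]
            #shift = tf.get_variable('shift', shape=[c],
            #                        initializer=tf.zeros_initializer())
            #scale = tf.get_variable('scale', shape=[c],
            #                        initializer=tf.random_normal_initializer(1.0, 0.02))
            #out = scale * normalized + shift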
    elif norm == 'batch':
        with tf.variable_scope('batch_norm', reuse=reuse):
            out = tf.contrib.layers.batch_norm(input,
                                               decay=0.99, center=True,
                                               scale=True, is_training=is_train,
                                               updates_collections=None)
    else:
        out = input

    return out
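To see what the `'instance'` branch of `_norm` computes without a TensorFlow 1.x install, here is a minimal NumPy sketch. It assumes an NHWC array and reduces over axes 1 and 2, as `tf.nn.moments(input, [1, 2], keep_dims=True)` does. The optional `scale`/`shift` arguments mirror the commented-out affine transform. The function name `instance_norm` is hypothetical and is not part of the project.

```python
import numpy as np


def instance_norm(x, eps=1e-5, scale=None, shift=None):
    """Instance normalization for an NHWC array (illustrative NumPy sketch).

    Mean and variance are computed separately for every sample and every
    channel, over the spatial axes (1, 2) only.
    """
    mean = x.mean(axis=(1, 2), keepdims=True)
    variance = x.var(axis=(1, 2), keepdims=True)
    normalized = (x - mean) / np.sqrt(variance + eps)
    # Optional per-channel affine transform; scale/shift broadcast over the last axis.
    if scale is not None:
        normalized = normalized * scale
    if shift is not None:
        normalized = normalized + shift
    return normalized


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 4, 3))  # N=2, H=W=4, C=3
y = instance_norm(x)
print(y.shape)  # (2, 4, 4, 3)
```

Each (sample, channel) slice of `y` has approximately zero mean and unit variance. This is why instance norm, unlike batch norm, behaves the same at training and test time and needs no `is_train` flag.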
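The `'batch'` branch delegates to `tf.contrib.layers.batch_norm` with `decay=0.99`, `center=True`, `scale=True`, and `updates_collections=None`. The last setting makes the moving-average updates run in place rather than as separate ops. The NumPy sketch below imitates that behaviour and is not the library's implementation. The class `BatchNorm` is hypothetical, and `eps=1e-3` is an assumed value. In training mode it normalizes with the current batch statistics and updates exponential moving averages. In inference mode it uses the stored averages.

```python
import numpy as np


class BatchNorm:
    """Hypothetical NumPy batch norm for NHWC inputs (sketch, not tf.contrib)."""

    def __init__(self, channels, decay=0.99, eps=1e-3):
        self.decay = decay
        self.eps = eps  # assumed epsilon value
        self.gamma = np.ones(channels)   # learnable scale (scale=True)
        self.beta = np.zeros(channels)   # learnable offset (center=True)
        self.moving_mean = np.zeros(channels)
        self.moving_var = np.ones(channels)

    def __call__(self, x, is_training):
        if is_training:
            # Statistics over batch and spatial axes, one value per channel.
            mean = x.mean(axis=(0, 1, 2))
            var = x.var(axis=(0, 1, 2))
            # In-place exponential moving-average update.
            d = self.decay
            self.moving_mean = d * self.moving_mean + (1 - d) * mean
            self.moving_var = d * self.moving_var + (1 - d) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
bn = BatchNorm(channels=3)
x = rng.normal(loc=5.0, scale=2.0, size=(8, 4, 4, 3))
y_train = bn(x, is_training=True)
y_infer = bn(x, is_training=False)
```

With `decay=0.99`, a single training step moves the averages only 1% of the way toward the batch statistics. Inference output therefore differs noticeably from training output until many batches have been seen. This is a common source of surprise when `is_train` is toggled early in training.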