normalization.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tensorbayes 作者: RuiShu 项目源码 文件源码
def instance_norm(x,
                  shift=True,
                  scale=True,
                  eps=1e-3,
                  scope=None,
                  reuse=None):

    # Expect a 4-D Tensor
    C = x._shape_as_list()[-1]

    with tf.variable_scope(scope, 'instance_norm', reuse=reuse):
        # Get mean and variance, normalize input
        m, v = tf.nn.moments(x, [1, 2], keep_dims=True)
        output = (x - m) * tf.rsqrt(v + eps)

        if scale:
            output *= tf.get_variable('gamma', C, initializer=tf.ones_initializer)

        if shift:
            output += tf.get_variable('beta', C, initializer=tf.zeros_initializer)

    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号