pot.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:adagan 作者: tolstikhin 项目源码 文件源码
def compute_moments(_inputs, moments=(2, 3)):
    """Compute sliding-window central moments of an image batch.

    For every window size in (3, 4, 5) the input, its square and (when
    needed) its cube are averaged per channel over each window position
    using a 'VALID' convolution. From these local averages the function
    derives the local variance and/or the unnormalized third central
    moment, flattens each to ``[batch, features]`` and concatenates all of
    them along axis 1.

    Args:
        _inputs: 4-D tensor whose dimensions 1, 2 and 3 are statically known
            and are read as height, width and channels respectively (the
            default layout of ``tf.nn.conv2d``).
        moments: Container of ints selecting the outputs: ``2`` adds the
            local variance, ``3`` adds the unnormalized third central
            moment. Other values are ignored.

    Returns:
        A 2-D tensor. For each window size, in order, it holds the variance
        features (if requested) followed by the third-moment features (if
        requested).

    Raises:
        ValueError: If ``moments`` contains neither 2 nor 3, so there would
            be nothing to concatenate.
    """
    want_variance = 2 in moments
    want_skewness = 3 in moments
    if not (want_variance or want_skewness):
        raise ValueError(
            'moments must contain 2 and/or 3, got %r' % (moments,))

    height = int(_inputs.get_shape()[1])
    width = int(_inputs.get_shape()[2])
    channels = int(_inputs.get_shape()[3])

    _inputs_sq = tf.square(_inputs)
    # Only build the cube when the third moment is actually requested.
    _inputs_cube = tf.pow(_inputs, 3) if want_skewness else None

    def _mean_kernel(kernel_size):
        """Build a conv filter that averages each channel independently."""
        # A channels x channels identity at every spatial tap keeps channels
        # separate (no cross-channel mixing); dividing by the number of taps
        # turns the windowed sum into a windowed mean.
        taps = kernel_size * kernel_size
        kernel = tf.eye(num_rows=channels, num_columns=channels,
                        batch_shape=[taps])
        kernel = tf.reshape(
            kernel, [kernel_size, kernel_size, channels, channels])
        return kernel / taps

    def _local_mean_flat(x, kernel, kernel_size):
        """Apply the averaging kernel and flatten to [batch, features]."""
        averaged = tf.nn.conv2d(
            x, kernel, strides=[1, 1, 1, 1], padding='VALID')
        size = prod_dim(averaged)
        # Sanity check on the 'VALID' output shape.
        expected = ((height - kernel_size + 1)
                    * (width - kernel_size + 1) * channels)
        assert size == expected, size
        return tf.reshape(averaged, [-1, size])

    outputs = []
    for window in [3, 4, 5]:
        # The same kernel serves the mean, square and cube convolutions.
        kernel = _mean_kernel(window)
        mean = _local_mean_flat(_inputs, kernel, window)
        square = _local_mean_flat(_inputs_sq, kernel, window)
        var = square - tf.square(mean)  # E[x^2] - E[x]^2
        if want_variance:
            outputs.append(var)
        if want_skewness:
            cube = _local_mean_flat(_inputs_cube, kernel, window)
            # E[(x - m)^3] = E[x^3] - 3*m*var - m^3; not divided by std^3.
            skewness = cube - 3.0 * mean * var - tf.pow(mean, 3)
            outputs.append(skewness)
    return tf.concat(outputs, 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号