def compute_moments(_inputs, moments=(2, 3), kernel_sizes=(3, 4, 5)):
    """From an image input, compute local (windowed) central moments.

    For each window size in ``kernel_sizes``, a per-channel k x k box
    filter (stride 1, ``'VALID'`` padding) estimates the local E[X] and
    E[X^2]. It also estimates E[X^3] when the third moment is requested.
    From these it derives:

    * variance (2nd central moment):  E[X^2] - E[X]^2
    * unnormalized skewness (3rd central moment, NOT divided by sigma^3):
      E[X^3] - 3 * E[X] * Var - E[X]^3

    Args:
        _inputs: 4-D tensor. Dimensions 1, 2 and 3 are read as height,
            width and channels, and must be statically known. Presumably
            NHWC layout (as ``tf.nn.conv2d`` defaults to) -- confirm with
            callers.
        moments: Collection of moment orders to emit. Only 2 (variance)
            and 3 (unnormalized skewness) are recognized.
        kernel_sizes: Square window sizes to evaluate, in output order.

    Returns:
        A 2-D tensor. For each kernel size in order, it holds the flattened
        variance map (if 2 in ``moments``) followed by the flattened
        skewness map (if 3 in ``moments``), concatenated along axis 1.

    Raises:
        ValueError: if ``moments`` selects neither 2 nor 3, or
            ``kernel_sizes`` is empty, leaving nothing to concatenate.
    """
    # Evaluate membership once (also safe if a one-shot iterable is passed).
    want_var = 2 in moments
    want_skew = 3 in moments
    if not (want_var or want_skew) or not kernel_sizes:
        raise ValueError(
            "compute_moments: nothing to compute "
            "(moments=%r, kernel_sizes=%r)" % (moments, kernel_sizes))

    height = int(_inputs.get_shape()[1])
    width = int(_inputs.get_shape()[2])
    channels = int(_inputs.get_shape()[3])

    _inputs_sq = tf.square(_inputs)
    # Only needed for the third moment; avoid building the op otherwise.
    _inputs_cube = tf.pow(_inputs, 3) if want_skew else None

    def _box_filter(kernel_size):
        """Build a [k, k, C, C] filter averaging each channel over a k x k window."""
        # An identity matrix per spatial tap maps input channel c only to
        # output channel c, so channels are averaged independently.
        eye = tf.eye(num_rows=channels, num_columns=channels,
                     batch_shape=[kernel_size * kernel_size])
        w = tf.reshape(eye, [kernel_size, kernel_size, channels, channels])
        return w / (kernel_size * kernel_size)

    def _conv_flatten(x, w, kernel_size):
        """Convolve x with w (stride 1, VALID) and flatten to [-1, features]."""
        summed = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='VALID')
        size = prod_dim(summed)
        # With VALID padding each spatial dim shrinks by kernel_size - 1.
        expected = (height - kernel_size + 1) * (width - kernel_size + 1) * channels
        assert size == expected, (size, expected)
        return tf.reshape(summed, [-1, size])

    outputs = []
    for kernel_size in kernel_sizes:
        # One filter per window size, shared by all moment estimates.
        w = _box_filter(kernel_size)
        mean = _conv_flatten(_inputs, w, kernel_size)
        mean_sq = _conv_flatten(_inputs_sq, w, kernel_size)
        var = mean_sq - tf.square(mean)
        if want_var:
            outputs.append(var)
        if want_skew:
            mean_cube = _conv_flatten(_inputs_cube, w, kernel_size)
            # E[(X-mu)^3] = E[X^3] - 3*mu*Var - mu^3  (unnormalized skewness)
            skewness = mean_cube - 3.0 * mean * var - tf.pow(mean, 3)
            outputs.append(skewness)
    return tf.concat(outputs, 1)
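# Comment list  (scraped web-page navigation label, translated)
# Table of contents  (scraped web-page navigation label, translated)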