def standardize_images(x):
  """Standardize a batch of images, mirroring tf.image.per_image_standardization.

  Each example is shifted by its own mean and divided by its own standard
  deviation. Both statistics are reduced over axes 1, 2 and 3. The standard
  deviation is floored at 1 / sqrt(dim1 * dim2 * 3) so that near-constant
  images are not divided by (almost) zero.

  Args:
    x: Image batch. Presumably shaped [batch, height, width, channels] with
      1 or 3 channels; TODO confirm with callers. It is cast to float first.

  Returns:
    The standardized float tensor, with its static last dimension set to 3.
    If the standardized tensor is statically known to be rank 4 with a single
    channel, that channel is repeated three times along axis 3 first.
  """
  with tf.name_scope("standardize_images", [x]):
    images = tf.to_float(x)
    reduce_axes = [1, 2, 3]
    # Per-example statistics; keep_dims so they broadcast back over `images`.
    mean = tf.reduce_mean(images, axis=reduce_axes, keep_dims=True)
    centered = images - mean
    variance = tf.reduce_mean(
        tf.square(centered), axis=reduce_axes, keep_dims=True)
    dims = shape_list(images)
    # NOTE(review): the channel factor is hard-coded to 3 instead of dims[3].
    # This agrees with the 3-channel output enforced below.
    element_count = tf.to_float(dims[1] * dims[2] * 3)
    min_stddev = tf.rsqrt(element_count)
    standardized = centered / tf.maximum(tf.sqrt(variance), min_stddev)

    # TODO(lukaszkaiser): remove this hack; greedy decoding needs it for now.
    static_shape = standardized.shape
    if static_shape and len(static_shape) == 4 and static_shape[3] == 1:
      # Per the original author, this tiling is not used at runtime; it only
      # feeds a dead tf.cond branch.
      standardized = tf.concat([standardized] * 3, axis=3)
    standardized.set_shape([None, None, None, 3])
    return standardized
评论列表
文章目录