def total_variation(image_batch, epsilon=0.0):
    """Compute the isotropic total variation of a batch of images.

    For each pixel that has a following neighbour along both spatial axes,
    the forward differences d1 (along axis 1) and d2 (along axis 2) are
    combined as sqrt(d1**2 + d2**2 + epsilon).  The result sums these values
    over all such pixels, all channels and the whole batch.

    :param image_batch: A 4-D tensor of shape
        [batch_size, spatial_1, spatial_2, channels].  An earlier version of
        this docstring named the spatial axes [width, height].  The result is
        symmetric in the two spatial axes, so their order does not matter.
        Plain slicing is used, so the spatial sizes may be unknown (``None``)
        when the graph is built.
    :param epsilon: Non-negative Python number added under the square root.
        The default 0.0 reproduces the plain sum exactly.  The gradient of
        sqrt is infinite at 0, so a pixel whose two differences are both
        exactly zero (for example a flat or clipped region) can produce NaN
        gradients.  A small positive value such as 1e-8 avoids this, at the
        cost of adding about sqrt(epsilon) per summed element.
    :return: A scalar tensor holding the summed total variation.
    :raises ValueError: If ``epsilon`` is negative.
    """
    if epsilon < 0:
        raise ValueError("epsilon must be non-negative")

    # Both differences start from the same anchor pixels x[:, i, j, :] with
    # i < size_1 - 1 and j < size_2 - 1.  That gives both difference tensors
    # the shape [batch, size_1 - 1, size_2 - 1, channels], so they can be
    # combined element-wise.  The `-` operator is used instead of tf.sub,
    # which was removed in TensorFlow 1.0.
    anchor = image_batch[:, :-1, :-1, :]
    diff_axis1 = anchor - image_batch[:, 1:, :-1, :]
    diff_axis2 = anchor - image_batch[:, :-1, 1:, :]

    sum_of_pixel_diffs_squared = tf.square(diff_axis1) + tf.square(diff_axis2)
    if epsilon:
        # Only added when requested, so the default graph and value match
        # the plain sqrt-of-sum-of-squares exactly.
        sum_of_pixel_diffs_squared = sum_of_pixel_diffs_squared + epsilon

    # TODO: Should this be normalized by the number of pixels?
    return tf.reduce_sum(tf.sqrt(sum_of_pixel_diffs_squared))
style_helpers.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录