def conv_feat_map_tensor_gram(conv_fmap_tensor):
    """Compute the Gram matrix of each image's conv feature maps.

    Used in style transfer.

    Args:
        conv_fmap_tensor: A rank-4 tensor. Axis 0 is treated as the image
            (batch) axis and axis 3 as the filter/channel axis. Axes 1 and 2
            are flattened together into spatial positions. Presumably NHWC
            layout -- TODO confirm against callers.

    Returns:
        A tensor of shape ``[num_images, num_filters, num_filters]`` with the
        same dtype as ``conv_fmap_tensor``. Entry ``[b, i, j]`` is the inner
        product (over spatial positions) of filter maps ``i`` and ``j`` of
        image ``b``, divided by ``shape[1] * shape[2] * num_filters``.

    Raises:
        tf.errors.InvalidArgumentError: if the input rank is not 4. This is
            raised immediately in eager mode or when the rank is statically
            known; otherwise it is raised when the graph is run.
    """
    rank_check = tf.assert_equal(tf.rank(conv_fmap_tensor), 4)
    # In graph mode tf.assert_equal only returns an op. Without a control
    # dependency nothing ever executes it, and the check is silently skipped.
    with tf.control_dependencies([rank_check]):
        shape = tf.shape(conv_fmap_tensor)
        num_images = shape[0]
        # Axes 1 and 2 are the spatial axes; only their product matters here.
        dim1 = shape[1]
        dim2 = shape[2]
        num_filters = shape[3]
        # [num_images, dim1 * dim2, num_filters]: one row per spatial position.
        filters = tf.reshape(conv_fmap_tensor,
                             tf.stack([num_images, -1, num_filters]))
        # filters^T @ filters -> [num_images, num_filters, num_filters].
        grams = tf.matmul(filters, filters, transpose_a=True)
        # tf.to_float is gone in TF2 and always yields float32, which broke
        # float16/float64 inputs. Cast the normaliser to the Gram's dtype.
        normaliser = tf.cast(dim1 * dim2 * num_filters, grams.dtype)
        return grams / normaliser
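# 评论列表 ("Comment list") -- web-page navigation residue, not code
# 文章目录 ("Table of contents") -- web-page navigation residue, not code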