def tf_batch_gram_matrix(batch):
    """Compute a size-normalized Gram matrix for every example in a batch.

    Args:
        batch: rank-4 tensor, unpacked below as
            (batch, height, width, channels). Presumably NHWC feature maps
            of a floating-point dtype -- confirm against callers.

    Returns:
        Tensor of shape (batch, channels, channels). For each example this is
        F^T @ F / (height * width * channels), where F is the
        (height * width, channels) matrix of flattened spatial positions.
    """
    # NOTE(review): assumes `tensor_shape` (defined elsewhere) yields values
    # usable both in tf.reshape and as a Python divisor, e.g. static ints.
    _, height, width, channels = tensor_shape(batch)
    # Flatten the spatial dims: each row is one position's channel vector.
    features = tf.reshape(batch, (-1, height * width, channels))
    # tf.batch_matrix_transpose / tf.batch_matmul were removed in TF 1.0.
    # tf.matmul batches over the leading dim, and transpose_a=True yields
    # F^T @ F without materializing the transposed tensor.
    gram = tf.matmul(features, features, transpose_a=True)
    # Normalize by the number of elements per example.
    return gram / (height * width * channels)
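# 评论列表 ("Comment list") -- web-page navigation residue, not code.
# 文章目录 ("Table of contents") -- web-page navigation residue, not code.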