def attention_image_summary(attn, image_shapes=None):
    """Compute color image summary.

    The head axis of `attn` is moved last, zero-padded to a multiple of 3 and
    collapsed into 3 color channels with a max-reduction, so that each head
    contributes to one of R, G or B. The resulting image is emitted through
    `tf.summary.image` under the name "attention"; nothing is returned.

    Args:
      attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
      image_shapes: optional tuple of integer scalars.
        If the query positions and memory positions represent the
        pixels of flattened images, then pass in their dimensions:
          (query_rows, query_cols, memory_rows, memory_cols).
        If the query positions and memory positions represent the
        pixels x channels of flattened images, then pass in their dimensions:
          (query_rows, query_cols, query_channels,
           memory_rows, memory_cols, memory_channels).

    Raises:
      ValueError: if `image_shapes` is given but has neither 4 nor 6 entries.
    """
    # Validate before creating any ops, and with a real exception rather than
    # `assert`, so the check still runs under `python -O`.
    if image_shapes is not None and len(image_shapes) not in (4, 6):
        raise ValueError(
            "image_shapes must have 4 or 6 entries, got %d"
            % len(image_shapes))

    num_heads = common_layers.shape_list(attn)[1]
    # [batch, query_length, memory_length, num_heads]
    image = tf.transpose(attn, [0, 2, 3, 1])
    image = tf.pow(image, 0.2)  # for high-dynamic-range
    # Each head will correspond to one of RGB.
    # Pad the heads to be a multiple of 3: mod(-num_heads, 3) is the number
    # of extra zero-valued heads needed to reach the next multiple of 3.
    image = tf.pad(
        image, [[0, 0], [0, 0], [0, 0], [0, tf.mod(-num_heads, 3)]])
    # NOTE(review): split_last_dimension is defined elsewhere; presumably it
    # reshapes [..., heads] into [..., 3, heads // 3] -- confirm. Axis 4 is
    # then max-reduced, leaving a trailing axis of size 3 (the color channels).
    image = split_last_dimension(image, 3)
    image = tf.reduce_max(image, 4)

    if image_shapes is not None:
        if len(image_shapes) == 4:
            q_rows, q_cols, m_rows, m_cols = image_shapes
            # Un-flatten query and memory positions into 2-D grids, then
            # interleave them so rows pair with rows and columns with columns.
            image = tf.reshape(
                image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
            image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
            image = tf.reshape(
                image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
        else:
            # Length is 6 here (checked at the top of the function).
            (q_rows, q_cols, q_channels,
             m_rows, m_cols, m_channels) = image_shapes
            image = tf.reshape(
                image,
                [-1, q_rows, q_cols, q_channels,
                 m_rows, m_cols, m_channels, 3])
            # After this transpose the axes are
            # [batch, q_rows, m_rows, q_channels,
            #  q_cols, m_cols, m_channels, 3].
            image = tf.transpose(image, [0, 1, 4, 3, 2, 5, 6, 7])
            image = tf.reshape(
                image,
                [-1,
                 q_rows * m_rows * q_channels,
                 q_cols * m_cols * m_channels,
                 3])

    tf.summary.image("attention", image, max_outputs=1)
# Comment list
# Table of contents