def scatter_blocks_2d(x, indices, shape):
    """Scatter the flattened blocks of ``x`` into ``shape`` using ``indices``.

    Every dimension of ``x`` between the first two and the last one is
    collapsed into a single "length" axis. Rows along that axis are then
    written to the positions listed in ``indices``, and the result is
    reshaped to ``shape``.

    Args:
      x: Tensor whose first two dims are treated as (batch, heads) and whose
        last dim is treated as the feature depth; all dims in between are
        flattened into one length axis.
      indices: Integer tensor of target positions along the flattened length
        axis. It is flattened, so its element count must equal that length.
        NOTE(review): presumably a permutation of range(length). Per the
        tf.scatter_nd contract, repeated indices would be summed and
        unaddressed positions left as zeros. Confirm against callers.
      shape: Shape the scattered result is reshaped to before returning.

    Returns:
      The scattered tensor, reshaped to ``shape``.
    """
    dims = common_layers.shape_list(x)
    flat = tf.reshape(x, [dims[0], dims[1], -1, dims[-1]])
    # Put the length axis first, giving [length, batch, heads, dim], so that
    # scatter_nd addresses whole rows along axis 0.
    length_major = tf.transpose(flat, [2, 0, 1, 3])
    target_shape = common_layers.shape_list(length_major)
    # scatter_nd wants indices as [num_updates, 1] to index only axis 0.
    positions = tf.reshape(indices, [-1, 1])
    scattered = tf.scatter_nd(positions, length_major, target_shape)
    # Move the axes back to [batch, heads, length, dim] before the final reshape.
    restored = tf.transpose(scattered, [1, 2, 0, 3])
    return tf.reshape(restored, shape)
评论列表
文章目录