common_attention.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def scatter_blocks_2d(x, indices, shape):
  """scatters blocks from x into shape with indices."""
  x_shape = common_layers.shape_list(x)
  # [length, batch, heads, dim]
  x_t = tf.transpose(
      tf.reshape(x, [x_shape[0], x_shape[1], -1, x_shape[-1]]), [2, 0, 1, 3])
  x_t_shape = common_layers.shape_list(x_t)
  indices = tf.reshape(indices, [-1, 1])
  scattered_x = tf.scatter_nd(indices, x_t, x_t_shape)
  scattered_x = tf.transpose(scattered_x, [1, 2, 0, 3])
  return tf.reshape(scattered_x, shape)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号