common_attention.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def gather_indices_2d(x, block_shape, block_stride):
  """Getting gather indices."""
  # making an identity matrix kernel
  kernel = tf.eye(block_shape[0] * block_shape[1])
  kernel = reshape_range(kernel, 0, 1, [block_shape[0], block_shape[1], 1])
  # making indices [1, h, w, 1] to appy convs
  x_shape = common_layers.shape_list(x)
  indices = tf.range(x_shape[2] * x_shape[3])
  indices = tf.reshape(indices, [1, x_shape[2], x_shape[3], 1])
  indices = tf.nn.conv2d(
      tf.cast(indices, tf.float32),
      kernel,
      strides=[1, block_stride[0], block_stride[1], 1],
      padding="VALID")
  # making indices [num_blocks, dim] to gather
  dims = common_layers.shape_list(indices)[:3]
  if all([isinstance(dim, int) for dim in dims]):
    num_blocks = functools.reduce(operator.mul, dims, 1)
  else:
    num_blocks = tf.reduce_prod(dims)
  indices = tf.reshape(indices, [num_blocks, -1])
  return tf.cast(indices, tf.int32)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号