fft_tree_constrained_inference.py source code

python

Project: wip-constrained-extractor    Author: brain-research
import tensorflow as tf


def padded_gather_nd(params, indices, r, idx_rank):
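  """Version of gather_nd that supports gradients and blank indices.

  Works like gather_nd with full-length index tuples, except that an index
  tuple whose last component is -1 (for example, all -1s) gathers a 0
  instead of an element of params.

  Args:
    params: tensor from which to gather (see gather_nd). Its shape must be
      fully defined at graph-construction time.
    indices: int32 tensor of indices (see gather_nd), with a fully defined
      shape. Its last dimension must have size r, so that each index tuple
      addresses a single element of params.
    r: rank of params.
    idx_rank: rank of indices.

  Returns:
    result: tensor of shape indices.shape[:-1] containing the elements
      gathered from params (0 wherever a blank index was given).
  """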

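  # Treat -1 indices as always gathering zeros: pad a 0 onto the beginning
  # of the final dimension of params and shift the last index component by
  # +1, so that -1 lands on the padding and every valid index still points
  # at its original element.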
  broadcasted_shift = tf.reshape(
      tf.one_hot(
          [r - 1], r, dtype=tf.int32), [1] * (idx_rank - 1) + [r])
  shifted_idx = indices + broadcasted_shift
  # Unused index components may contain garbage (e.g. -1); clamp them to 0.
  shifted_idx = tf.maximum(0, shifted_idx)
  padded_params = tf.pad(params, [[0, 0]] * (r - 1) + [[1, 0]])

  # tf.gather_nd is not used here because its gradient did not work when
  # this was written:
  #   return tf.gather_nd(padded_params, shifted_idx)

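  # HACK: work around the missing gather_nd gradient by flattening params,
  # converting each index tuple into a row-major flat offset, and gathering
  # with tf.gather, which does have a gradient.
  # params has shape of rank r
  # indices has shape of rank idx_rank
  params_shape = padded_params.get_shape().as_list()
  idx_shape = shifted_idx.get_shape().as_list()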
  flat_params_x_size = 1
  for dim in params_shape:
    flat_params_x_size *= dim
  flat_idx_x_size = 1
  for dim in idx_shape[:-1]:
    flat_idx_x_size *= dim

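  # Row-major strides computed in Python from the static shape:
  # index_strides[i] is the product of params_shape[i + 1:].
  index_strides = [1] * r
  for i in range(r - 2, -1, -1):
    index_strides[i] = index_strides[i + 1] * params_shape[i + 1]
  index_strides = tf.reshape(
      tf.constant(index_strides, dtype=tf.int32), [1] * (idx_rank - 1) + [-1])
  flat_idx = tf.reduce_sum(shifted_idx * index_strides, axis=idx_rank - 1)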
  flat_idx = tf.reshape(flat_idx, [flat_idx_x_size])
  flat_params = tf.reshape(padded_params, [flat_params_x_size])

  result = tf.gather(flat_params, flat_idx)
  result = tf.reshape(result, idx_shape[:-1])

  return result
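To make the pad-and-shift and flat-offset tricks concrete without a TensorFlow install, here is a minimal pure-Python sketch of the same indexing logic on nested lists. The names `padded_gather_nd_reference`, `row_major_strides`, `shape_of`, `flatten`, and `pad_last_dim` are hypothetical helpers introduced only for illustration; they are not part of the wip-constrained-extractor project. The sketch assumes rectangular nested lists and an index list of shape [N, r] (a list of full index tuples), and it models gathered values only, not gradients.

```python
from typing import List, Sequence


def row_major_strides(shape: Sequence[int]) -> List[int]:
  """strides[i] is the product of shape[i + 1:], as in a C-ordered array."""
  strides = [1] * len(shape)
  for i in range(len(shape) - 2, -1, -1):
    strides[i] = strides[i + 1] * shape[i + 1]
  return strides


def shape_of(nested) -> List[int]:
  """Shape of a rectangular nested list, read along the first elements."""
  shape = []
  while isinstance(nested, list):
    shape.append(len(nested))
    nested = nested[0]
  return shape


def flatten(nested) -> list:
  """Flatten a nested list in row-major order."""
  if not isinstance(nested, list):
    return [nested]
  out = []
  for item in nested:
    out.extend(flatten(item))
  return out


def pad_last_dim(nested) -> list:
  """Prepend one 0 to every innermost row (like tf.pad with [1, 0] on the last axis)."""
  if isinstance(nested[0], list):
    return [pad_last_dim(row) for row in nested]
  return [0] + list(nested)


def padded_gather_nd_reference(params, index_tuples) -> list:
  """Hypothetical reference: one element per index tuple; last component -1 gives 0."""
  padded = pad_last_dim(params)
  strides = row_major_strides(shape_of(padded))
  flat = flatten(padded)
  result = []
  for idx in index_tuples:
    # Shift only the last component by +1 so that -1 points at the padded 0.
    shifted = list(idx[:-1]) + [idx[-1] + 1]
    # Clamp any remaining negatives to 0, like tf.maximum(0, shifted_idx).
    shifted = [max(0, i) for i in shifted]
    # Row-major flat offset, like reduce_sum(shifted_idx * index_strides).
    offset = sum(i * s for i, s in zip(shifted, strides))
    result.append(flat[offset])
  return result


params = [[10, 11, 12],
          [20, 21, 22]]
indices = [[0, 2], [1, 0], [-1, -1], [1, -1]]
print(padded_gather_nd_reference(params, indices))  # [12, 20, 0, 0]
```

Only the last component of an index tuple decides whether it is blank: [-1, 1] becomes [0, 2] in padded coordinates and returns params[0][1] (11 here), not 0. The TensorFlow function above behaves the same way, since it also shifts only the last component before applying tf.maximum.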