fft_tree_constrained_inference.py source file

python

Project: wip-constrained-extractor Author: brain-research
import tensorflow as tf


def add_leading_idx_3(t):
  """Utility to automatically add indices used by gather_nd.

  Args:
    t: int32 tensor of shape [b,s,ss,...,n] (static shape must be fully
      defined, rank at least 4)

  Returns:
    t: tensor of shape [b,s,ss,...,n+3] where the (b,s,ss)-indices are
      prepended onto the values in the final mode.
  """

  # Static shape as a list of Python ints.
  dims = t.get_shape().as_list()
  b_size, s_size, ss_size = dims[:3]
  b_idx = tf.reshape(tf.range(0, b_size), [b_size] + [1] * (len(dims) - 1))
  s_idx = tf.reshape(tf.range(0, s_size), [1, s_size] + [1] * (len(dims) - 2))
  ss_idx = tf.reshape(
      tf.range(0, ss_size), [1, 1, ss_size] + [1] * (len(dims) - 3))
  tiled_b_idx = tf.tile(b_idx, [1] + dims[1:-1] + [1])
  tiled_s_idx = tf.tile(s_idx, [dims[0], 1] + dims[2:-1] + [1])
  tiled_ss_idx = tf.tile(ss_idx, [dims[0], dims[1], 1] + dims[3:-1] + [1])
  # tf.concat takes (values, axis) since TensorFlow 1.0.
  t_with_b_s_ss_idx = tf.concat(
      [tiled_b_idx, tiled_s_idx, tiled_ss_idx, t], axis=len(dims) - 1)
  return t_with_b_s_ss_idx
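
To make the index layout concrete, here is a minimal NumPy sketch of the same idea. It is an analogue for illustration, not the project's TensorFlow code: the name `add_leading_idx_3_np` and the `params` / `choice` arrays are hypothetical, and the input is assumed to have rank 4 or higher, as the original function also requires. The sketch prepends the (b, s, ss) position of every entry to its last axis, then uses the result to pick one value per position, which is the kind of lookup `gather_nd` performs with such indices.

```python
import numpy as np


def add_leading_idx_3_np(t):
    """NumPy analogue: prepend (b, s, ss) position indices to the last axis."""
    t = np.asarray(t)
    lead = t.shape[:3]
    # grids has shape [3, b, s, ss]; grids[k] holds the k-th coordinate.
    grids = np.indices(lead)
    # Every axis except the last keeps its size; the last axis becomes 1.
    target = t.shape[:-1] + (1,)
    expanded = [
        np.broadcast_to(g.reshape(lead + (1,) * (t.ndim - 3)), target)
        for g in grids
    ]
    return np.concatenate(expanded + [t], axis=-1)


# Hypothetical data: for each (b, s, ss) position, choose one of 5 values.
params = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
choice = np.arange(2 * 3 * 4).reshape(2, 3, 4, 1) % 5

idx = add_leading_idx_3_np(choice)  # shape [2, 3, 4, 4]

# Each row of idx is a full coordinate (b, s, ss, k) into params.
gathered = params[tuple(np.moveaxis(idx, -1, 0))]  # shape [2, 3, 4]
```

Because each index row already carries its own (b, s, ss) position, the lookup needs no batch-aware gather: a plain coordinate lookup over the full index tensor is enough.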