def add_leading_idx_3(t):
    """Utility to automatically add indices used by gather_nd.

    For every position (i, j, k, ...) of `t`, the three leading indices
    (i, j, k) are prepended onto the vector stored in the final mode.

    Args:
      t: tensor of shape [b, s, ss, ..., n]. Its rank must be at least 4 and
        every dimension must be known statically.
    Returns:
      tensor of shape [b, s, ss, ..., n+3] where the (b, s, ss)-indices are
      prepended onto the values in the final mode. The index tensors are
      cast to `t.dtype` so the concatenation is well-typed.
    Raises:
      ValueError: if `t` has rank < 4 or any dimension is statically unknown.
    """
    dims = t.get_shape().as_list()
    rank = len(dims)
    # One leading index per axis 0..2 plus a final mode to append onto, so
    # rank >= 4 is required for the tile multiples below to line up.
    if rank < 4:
        raise ValueError(
            'add_leading_idx_3 expects a tensor of rank >= 4, got shape %s'
            % (dims,))
    # tf.range / tf.tile below need concrete sizes.
    if any(d is None for d in dims):
        raise ValueError(
            'add_leading_idx_3 requires a fully defined static shape, got %s'
            % (dims,))

    index_parts = []
    for axis in range(3):
        # Index vector laid out along `axis`, size-1 everywhere else.
        idx_shape = [1] * rank
        idx_shape[axis] = dims[axis]
        idx = tf.reshape(tf.range(0, dims[axis]), idx_shape)
        # Repeat across every non-final axis except `axis` itself; the final
        # mode stays width 1 so each part contributes a single column.
        multiples = dims[:-1] + [1]
        multiples[axis] = 1
        index_parts.append(tf.cast(tf.tile(idx, multiples), t.dtype))

    # NOTE(review): this is the pre-1.0 TensorFlow argument order
    # tf.concat(axis, values), kept for consistency with the rest of this
    # file; TF >= 1.0 expects tf.concat(values, axis).
    return tf.concat(rank - 1, index_parts + [t])
fft_tree_constrained_inference.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录