import tensorflow as tf


def padded_gather_nd(params, indices, r, idx_rank):
  """Version of gather_nd that supports gradients and blank indices.

  Works like gather_nd, but if an index is given as -1, a 0 will be inserted
  in that spot in the output tensor.

  Args:
    params: tensor from which to gather (see gather_nd).
    indices: tensor of indices (see gather_nd).
    r: rank of params.
    idx_rank: rank of indices.

  Returns:
    result: tensor shaped like indices, minus its last dimension, containing
      values gathered from params.
  """
  # treat -1 indices as always gathering zeros:
  # pad a 0 onto the beginning of the final dim of params
  broadcasted_shift = tf.reshape(
      tf.one_hot([r - 1], r, dtype=tf.int32), [1] * (idx_rank - 1) + [r])
  shifted_idx = indices + broadcasted_shift
  # unused indices might contain garbage; clamp them to 0
  shifted_idx = tf.maximum(0, shifted_idx)
  padded_params = tf.pad(params, [[0, 0]] * (r - 1) + [[1, 0]])
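  # after the shift, a -1 in the last index coordinate becomes 0, which now
  # points at the zero slice just padded onto the front of params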
  # no gather_nd for now because gradient doesn't work
  #   return tf.gather_nd(padded_params, shifted_idx)

  # HACK: work around lack of gradient for gather_nd
  # params has shape of rank r
  # indices has shape of rank idx_rank
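  # strategy: flatten params to 1-D and convert each multi-dimensional index
  # into a single row-major offset, so plain tf.gather (which has a
  # registered gradient) can do the lookup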
  params_shape = [d.value for d in padded_params.get_shape()]
  idx_shape = [d.value for d in shifted_idx.get_shape()]
  flat_params_x_size = 1
  for dim in params_shape:
    flat_params_x_size *= dim
  flat_idx_x_size = 1
  for dim in idx_shape[:-1]:
    flat_idx_x_size *= dim
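  # row-major flattening: an index vector (i_0, ..., i_{r-1}) maps to the
  # flat offset sum_k i_k * stride_k, where stride_k is the product of all
  # dimension sizes after dimension k (and the last stride is 1)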
  index_strides = tf.concat(
      0, [tf.cumprod(params_shape[1:], reverse=True), [1]])
  index_strides = tf.reshape(index_strides, [1] * (idx_rank - 1) + [-1])
  flat_idx = tf.reduce_sum(shifted_idx * index_strides, idx_rank - 1)
  flat_idx = tf.reshape(flat_idx, [flat_idx_x_size])
  flat_params = tf.reshape(padded_params, [flat_params_x_size])
  result = tf.gather(flat_params, flat_idx)
  result = tf.reshape(result, idx_shape[:-1])
  return result
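
A minimal usage sketch (the shapes and values are illustrative, and it
assumes the TF 0.x/1.x-era API the function itself uses, e.g.
tf.concat(dim, values) and tf.Session):

params = tf.constant([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])  # rank r = 2
indices = tf.constant([[0, 1],    # gathers params[0, 1] -> 2.0
                       [1, -1]])  # -1 in the last dim -> 0.0
result = padded_gather_nd(params, indices, r=2, idx_rank=2)
with tf.Session() as sess:
  print(sess.run(result))  # [2.0, 0.0]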
Source file: fft_tree_constrained_inference.py (Python)