def _tf_nth(fns, n):
"""Runs only the nth element of fns, where n is a scalar integer tensor."""
cases = [(tf.equal(tf.constant(i, n.dtype), n), fn)
for i, fn in enumerate(fns)]
final_pred, final_fn = cases.pop()
def default():
with tf.control_dependencies([
tf.Assert(final_pred, [n, len(fns)], name='nth_index_error')]):
return final_fn()
if len(fns) == 1: return default()
return tf.case(cases, default)
评论列表
文章目录