def gather_tree(values, parents): """Tensor version of gather_tree_py""" res = tf.py_func( func=gather_tree_py, inp=[values, parents], Tout=values.dtype) res.set_shape(values.get_shape().as_list()) return res