def th_gather_nd(x, coords):
    """Gather elements of ``x`` at N-dimensional integer coordinates.

    Each row along the last axis of ``coords`` is one full index into ``x``
    (one entry per dimension of ``x``). The element at that position is
    returned.

    Args:
        x: Tensor to gather from. It is made contiguous first, so
            non-contiguous inputs (transposes, expands, ...) are handled.
        coords: Integer tensor of shape ``(..., x.dim())``. The usual case is
            ``(N, x.dim())``, which yields a 1-D result of length ``N``.

    Returns:
        Tensor of shape ``coords.shape[:-1]`` holding the gathered values.

    Raises:
        RuntimeError: if ``coords`` is not an integer tensor, or if its last
            dimension does not match ``x.dim()``. An out-of-range flat index
            is also rejected by ``index_select``.
    """
    x = x.contiguous()
    if (coords.dtype.is_floating_point or coords.dtype.is_complex
            or coords.dtype == th.bool):
        raise RuntimeError(
            f"th_gather_nd: coords must be an integer tensor, got {coords.dtype}"
        )
    if coords.dim() == 0 or coords.shape[-1] != x.dim():
        raise RuntimeError(
            f"th_gather_nd: coords last dim must equal x.dim()={x.dim()}, "
            f"got coords of shape {tuple(coords.shape)}"
        )
    coords = coords.long()
    # Build the stride vector on coords' device. The legacy th.LongTensor(...)
    # constructor always allocates on the CPU, which broke CUDA coords.
    strides = th.tensor(x.stride(), dtype=th.long, device=coords.device)
    # Flat offset of each coordinate in x's contiguous storage. An elementwise
    # multiply plus a sum is used instead of mv(): it works for batched coords
    # and avoids integer matmul, which some backends do not support.
    inds = (coords * strides).sum(dim=-1)
    # x is contiguous, so view(-1) is its storage order laid out flat.
    # index_select (rather than advanced indexing) rejects negative flat
    # indices instead of silently wrapping them.
    flat = x.view(-1)
    return flat.index_select(0, inds.reshape(-1)).view(inds.shape)