def get_step(v, dv):
    """Return the largest step ``alpha`` keeping ``v + alpha * dv`` nonnegative.

    Only components with ``dv < 0`` can drive ``v + alpha * dv`` down to zero.
    For each such component the boundary is hit at ``alpha = -v / dv``, and the
    smallest of those ratios is returned. If no component of ``dv`` is
    negative, the step is unbounded in that sense and ``1`` is returned.

    Args:
        v: Current values (tensor). Assumed elementwise nonnegative -- TODO
            confirm with callers; if some entry is negative the returned step
            may be negative.
        dv: Direction of change, broadcast-compatible with ``v`` (the mask is
            built from ``dv``, so its shape must match the broadcast result).

    Returns:
        A 0-d tensor holding the minimum ratio when some ``dv`` entry is
        negative, otherwise the Python int ``1``.
    """
    # Zero or positive directions never bound the step and must be excluded:
    # they would yield -inf / NaN (dv == 0) or large negative ratios (tiny
    # positive dv), producing an invalid step.
    descending = dv < 0
    if not torch.any(descending):
        return 1
    # Division by zero in unselected entries is harmless: those ratios are
    # masked out before taking the minimum.
    ratios = -v / dv
    return torch.min(ratios[descending])