def psi(x, linearized=False):
    """Continuous, monotonically decreasing extension of ``cos`` for angular margins.

    Args:
        x: Angle(s) in radians. A ``torch.Tensor`` is used as-is; Python
            scalars and array-likes are converted with ``torch.as_tensor``.
        linearized: Selects the variant.

            * ``False`` -- the piecewise-cosine form from the paper
              (presumably the SphereFace / A-Softmax psi with
              ``x = m * theta``; NOTE(review): confirm which paper is meant):
              ``psi(x) = (-1)**k * cos(x) - 2*k`` with ``k = floor(x / pi)``.
            * ``True`` -- a piecewise form ``min(pi/2 - x, cos(x))``: it equals
              ``cos(x)`` for ``x <= pi/2`` and the straight line
              ``pi/2 - x`` beyond that point.

    Returns:
        A tensor with the same shape as ``x`` (both variants are elementwise).
    """
    x = torch.as_tensor(x)
    if linearized:
        # cos(x) - (pi/2 - x) has derivative 1 - sin(x) >= 0 and is zero at
        # pi/2, so the minimum switches from cos to the line exactly there.
        return torch.minimum(np.pi / 2 - x, x.cos())
    ks = torch.floor(x / np.pi)
    # (1 - 2 * (k % 2)) == (-1)**k. torch's % takes the sign of the divisor,
    # so this also holds for negative k. Subtracting 2k makes the pieces meet
    # at every multiple of pi (both sides equal 1 - 2k there).
    return (1 - 2 * (ks % 2)) * x.cos() - (2 * ks)
评论列表
文章目录