def task_loss(Y_sched, Y_actual, params):
    """Return the task loss of a schedule against actual values, averaged over dim 0.

    Each element's loss is the sum of three terms:

    * ``params["gamma_under"] * max(Y_actual - Y_sched, 0)``: a linear
      penalty on the amount by which the actual value exceeds the schedule;
    * ``params["gamma_over"] * max(Y_sched - Y_actual, 0)``: a linear
      penalty on the amount by which the schedule exceeds the actual value;
    * ``0.5 * (Y_sched - Y_actual) ** 2``: a quadratic penalty on any
      mismatch.

    Args:
        Y_sched: Tensor of scheduled values.
        Y_actual: Tensor of actual values, broadcast-compatible with
            ``Y_sched``. NOTE(review): presumably both are
            (batch, horizon) so the result is a per-horizon loss — confirm
            against callers.
        params: Mapping providing the scalar (or broadcastable) weights
            ``"gamma_under"`` and ``"gamma_over"``.

    Returns:
        Tensor of element-wise losses reduced with ``.mean(0)``.
    """
    # Positive where the actual value exceeds the scheduled one. Computing it
    # once is exact: in IEEE arithmetic (a - b) == -(b - a), and the square is
    # sign-symmetric, so values and gradients match the three-subtraction form.
    diff = Y_actual - Y_sched
    under = torch.clamp(diff, min=0)   # shortfall: Y_actual above Y_sched
    over = torch.clamp(-diff, min=0)   # surplus: Y_sched above Y_actual
    return (params["gamma_under"] * under +
            params["gamma_over"] * over +
            0.5 * diff ** 2).mean(0)
评论列表
文章目录