def rocket_distillation(y, teacher_scores, labels, T, alpha, reduction="mean", dim=None):
    """Return the temperature-scaled distillation loss between student and teacher logits.

    Computes ``KL(softmax(teacher_scores / T) || softmax(y / T))`` via
    ``F.kl_div`` and multiplies it by ``T * T * 2 * alpha``.

    Args:
        y: Student logits (passed through ``log_softmax``).
        teacher_scores: Teacher logits, same shape as ``y`` (passed through
            ``softmax``). They are not detached here, so gradients flow into
            them unless the caller detaches them first.
        labels: Unused. Kept only so the signature stays compatible with callers.
        T: Softmax temperature; both logit tensors are divided by it.
        alpha: Weight folded into the final scale factor.
        reduction: Forwarded to ``F.kl_div``. The default ``"mean"`` (kept for
            backward compatibility) averages over every element, i.e. divides
            by batch size *and* number of classes. ``"batchmean"`` divides by
            batch size only and matches the mathematical KL divergence.
        dim: Dimension over which the softmaxes normalize. ``None`` reproduces
            the legacy implicit choice PyTorch made when no ``dim`` was given
            (dim 0 for 0-, 1- and 3-D inputs, dim 1 otherwise), without the
            deprecation warning. Pass it explicitly for 3-D inputs, where the
            legacy rule normalizes over dim 0.

    Returns:
        The scaled KL loss tensor (0-d for ``"mean"``, ``"sum"`` and ``"batchmean"``).
    """
    import torch.nn.functional as F

    if dim is None:
        # Same rule PyTorch applies when ``dim`` is omitted; spelled out so the
        # result is unchanged but the implicit-dim deprecation warning is not raised.
        dim = 0 if y.dim() in (0, 1, 3) else 1

    student_log_probs = F.log_softmax(y / T, dim=dim)
    teacher_probs = F.softmax(teacher_scores / T, dim=dim)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction=reduction)
    return kl * (T * T * 2.0 * alpha)