def backward(ctx, grad_output):
if ctx.tensor_first:
var, = ctx.saved_variables
return grad_output.mul(ctx.constant).mul(var.pow(ctx.constant - 1)), None
else:
var_result, = ctx.saved_variables
return None, grad_output.mul(var_result).mul_(math.log(ctx.constant))
评论列表
文章目录