def backward(ctx, grad_output):
    """Compute gradients of a linear map with respect to its inputs.

    The formulas below are the gradients of ``output = input @ weight.t()``
    (plus ``bias`` added to every row), which is presumably what the matching
    ``forward`` computes -- confirm against it.

    Args:
        ctx: Autograd context. Must provide ``saved_variables`` as the
            ``(input, weight, bias)`` triple (``bias`` may be ``None``),
            ``needs_input_grad``, and -- when ``bias`` is not ``None`` and its
            gradient is needed -- ``add_buffer``.
        grad_output: Gradient flowing back from the forward output; must be
            2-D because it is passed to ``torch.mm``.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``. An entry is ``None`` when
        the corresponding input does not need a gradient; ``grad_bias`` is
        always ``None`` when ``bias`` is ``None``.
    """
    # NOTE: ``saved_variables`` / ``Variable`` are kept for compatibility with
    # the pre-0.4 Variable API this code targets. On PyTorch >= 0.4 the
    # non-deprecated spelling is ``ctx.saved_tensors``, and wrapping in
    # ``Variable`` just returns the tensor.
    input, weight, bias = ctx.saved_variables
    grad_input = grad_weight = grad_bias = None

    if ctx.needs_input_grad[0]:
        grad_input = torch.mm(grad_output, weight)
    if ctx.needs_input_grad[1]:
        grad_weight = torch.mm(grad_output.t(), input)
    # Check ``bias`` first so needs_input_grad[2] is never read when forward
    # received only two inputs.
    if bias is not None and ctx.needs_input_grad[2]:
        # Assumes forward set ``ctx.add_buffer`` to a ones vector with one
        # entry per row of grad_output, making this a row-sum -- confirm.
        grad_bias = torch.mv(grad_output.t(), Variable(ctx.add_buffer))

    # Always return one gradient slot per possible forward input. A 2-tuple
    # breaks when forward was called with an explicit ``bias=None`` (autograd
    # then expects 3 gradients), while autograd accepts extra trailing
    # ``None`` gradients when forward received only two inputs.
    return grad_input, grad_weight, grad_bias
评论列表
文章目录