def backward(self, grad_output):
    """Return gradients for a linear op w.r.t. its input, weight and (optional) bias.

    The gradient formulas below (``grad_output @ weight`` and
    ``grad_output.t() @ input``) are consistent with a forward of the form
    ``output = input @ weight.t() (+ bias)``; the forward is not visible
    here, so treat that as an inference to confirm. ``torch.mm`` requires
    2-D operands, so ``grad_output``, ``weight`` and the saved input are
    expected to be matrices.

    Args:
        grad_output: gradient of the loss w.r.t. this op's output.

    Returns:
        ``(grad_input, grad_weight)`` when no bias was saved (or the saved
        bias is ``None``), otherwise ``(grad_input, grad_weight, grad_bias)``.
        Each entry is ``None`` when ``self.needs_input_grad`` says that
        input does not require a gradient.
    """
    saved = self.saved_tensors
    # Forward may have saved (input, weight) or (input, weight, bias).
    if len(saved) == 2:
        inp, weight = saved  # `inp` avoids shadowing the builtin `input`
        bias = None
    else:
        inp, weight, bias = saved

    grad_input = grad_weight = grad_bias = None
    if self.needs_input_grad[0]:
        grad_input = torch.mm(grad_output, weight)
    if self.needs_input_grad[1]:
        grad_weight = torch.mm(grad_output.t(), inp)
    if bias is not None and self.needs_input_grad[2]:
        # NOTE(review): `add_buffer` is presumably a ones vector of length
        # batch set by forward, making this a sum over rows -- confirm.
        ones = getattr(self, "add_buffer", None)
        if ones is not None:
            grad_bias = torch.mv(grad_output.t(), ones)
        else:
            # Same result as mv with a ones vector, without relying on
            # state that forward may not have stored.
            grad_bias = grad_output.sum(0)

    # Return one gradient per forward input: two without a bias, three with.
    if bias is None:
        return grad_input, grad_weight
    return grad_input, grad_weight, grad_bias
评论列表
文章目录