def forward(ctx, input, dim=None, keepdim=None):
    """Compute the product of ``input`` and record state on ``ctx``.

    Args:
        ctx: Context object. Receives the attributes ``dim``, ``keepdim``,
            ``input_size`` (and ``result`` when ``dim`` is ``None``), and
            its ``save_for_backward`` method is called with the tensors
            listed below.
        input: Tensor-like object providing ``size()``, ``prod()`` and
            ``new()``.
        dim: Dimension to reduce over, or ``None`` to reduce over every
            element.
        keepdim: Forwarded to ``input.prod`` when ``dim`` is given.
            ``None`` is treated as ``False``.

    Returns:
        With ``dim=None``: ``input.new((total,))``, i.e. the overall product
        wrapped in a new one-element tensor created from ``input``.
        Otherwise: the result of ``input.prod(dim, keepdim=...)``.

    NOTE(review): the values stored on ``ctx`` are presumably consumed by a
    matching ``backward``, which is not visible here -- confirm before
    changing what is saved.
    """
    ctx.dim = dim
    # Normalise the "not specified" sentinel so later readers of ctx.keepdim
    # always see a concrete value.
    ctx.keepdim = False if keepdim is None else keepdim
    ctx.input_size = input.size()

    if dim is None:
        # Full reduction: keep the overall product on ctx and save only the
        # input tensor.
        ctx.result = input.prod()
        ctx.save_for_backward(input)
        return input.new((ctx.result,))

    # Reduction along one dimension. ctx.keepdim equals the caller's keepdim
    # when one was given and False otherwise, which matches prod's own
    # default, so a single call covers both cases.
    output = input.prod(dim, keepdim=ctx.keepdim)
    ctx.save_for_backward(input, output)
    return output
评论列表
文章目录