tensor.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def forward(ctx, input, dim=None, keepdim=None):
        ctx.dim = dim
        ctx.keepdim = False if keepdim is None else keepdim
        ctx.input_size = input.size()
        if dim is None:
            ctx.result = input.prod()
            ctx.save_for_backward(input)
            return input.new((ctx.result,))
        else:
            if keepdim is not None:
                output = input.prod(dim, keepdim=keepdim)
            else:
                output = input.prod(dim)
            ctx.save_for_backward(input, output)
            return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号