conv.py source code

python

Project: pytorch-coriander    Author: hughperkins
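import torch


def _thnn(self, fn_name, input, weight, *args):
    # A method (note `self`) of the enclosing convolution class in conv.py.
    # _thnn_convs is a module-level table defined elsewhere in conv.py that maps
    # a THNN class name to a dict of backend functions keyed by fn_name.
    impl = _thnn_convs[self.thnn_class_name(input)]
    if self.groups == 1:
        # Ungrouped convolution: a single backend call using the first buffer.
        return impl[fn_name](self, self._bufs[0], input, weight, *args)
    else:
        # Grouped convolution: call the backend once per group on matching
        # slices of input, weight and any extra tensors, then concatenate.
        res = []
        for g in range(self.groups):
            def group(tensor, dim=None):
                # Return the g-th of self.groups equal slices of `tensor` along
                # `dim`; None (e.g. a missing bias) passes through unchanged.
                if tensor is None:
                    return None
                if dim is None:
                    # 1-D tensors (bias) split on dim 0, others on channel dim 1.
                    dim = 0 if tensor.dim() == 1 else 1
                n = tensor.size(dim) // self.groups
                return tensor.narrow(dim, n * g, n).contiguous()

            # Input is split on its channel dim (1), weight on output channels (0).
            grouped_args = [group(input, 1), group(weight, 0)]
            grouped_args += [group(t) for t in args]
            res.append(impl[fn_name](self, self._bufs[g], *grouped_args))
        if fn_name == 'grad_params':
            # res holds one tuple of parameter gradients per group; zip(*res)
            # regroups them per parameter and concatenates each along dim 0.
            return [torch.cat(t, 0) if t[0] is not None else None
                    for t in zip(*res)]
        else:
            # Any other result is concatenated along the channel dim.
            return torch.cat(res, 1)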
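The grouped branch of `_thnn` relies on a grouped convolution being equal to independent convolutions over channel slices, concatenated along the channel axis. The following is a minimal pure-Python sketch of that idea under simplifying assumptions: a 1x1 convolution stands in for the THNN kernel, and tensors are nested lists with no batch dimension (so channels are axis 0 here rather than dim 1 as in `_thnn`). All function names (`conv1x1`, `narrow`, `grouped_conv1x1`, `grouped_conv1x1_direct`, `combine_grad_params`) are hypothetical illustration helpers, not PyTorch APIs.

```python
# Hypothetical sketch: the split/run/concatenate strategy of _thnn, with a
# 1x1 convolution standing in for the THNN kernel. Tensors are nested lists
# without a batch dimension: inp[channel][position], weight[out][in_per_group].


def conv1x1(inp, weight):
    """Dense 1x1 conv: out[o][p] = sum over c of weight[o][c] * inp[c][p]."""
    positions = len(inp[0])
    return [[sum(w[c] * inp[c][p] for c in range(len(inp))) for p in range(positions)]
            for w in weight]


def narrow(rows, g, groups):
    """Return the g-th of `groups` equal slices along axis 0 (like tensor.narrow)."""
    n = len(rows) // groups
    return rows[n * g:n * (g + 1)]


def grouped_conv1x1(inp, weight, groups):
    """Run conv1x1 once per group, then concatenate outputs along the channel axis."""
    out = []
    for g in range(groups):
        out += conv1x1(narrow(inp, g, groups), narrow(weight, g, groups))
    return out


def grouped_conv1x1_direct(inp, weight, groups):
    """Reference: output channel o only sees the input channels of its own group."""
    cin_g = len(inp) // groups
    cout_g = len(weight) // groups
    positions = len(inp[0])
    return [[sum(w[k] * inp[(o // cout_g) * cin_g + k][p] for k in range(cin_g))
             for p in range(positions)]
            for o, w in enumerate(weight)]


def combine_grad_params(per_group):
    """Mirror of the grad_params branch: regroup per-group (grad_w, grad_b) tuples
    per parameter and concatenate them, keeping None (no bias) as None."""
    return [sum(t, []) if t[0] is not None else None for t in zip(*per_group)]


inp = [[1, 2], [3, 4], [5, 6], [7, 8]]   # 4 input channels, 2 positions
weight = [[1, 1], [1, -1]]                 # 2 groups, 1 output channel each
print(grouped_conv1x1(inp, weight, 2))     # [[4, 6], [-2, -2]]
print(grouped_conv1x1(inp, weight, 2) == grouped_conv1x1_direct(inp, weight, 2))  # True
print(combine_grad_params([([1], [10]), ([2], [20])]))  # [[1, 2], [10, 20]]
```

The direct version checks the per-group result against the definition of a grouped convolution. In `_thnn` the same decomposition lets the ungrouped backend functions serve grouped layers, at the cost of the extra copies made by `.contiguous()` when a slice along dim 1 is not contiguous.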