def _thnn(self, fn_name, input, weight, *args):
    impl = _thnn_convs[self.thnn_class_name(input)]
    if self.groups == 1:
        # Single group: call the THNN backend directly on the full tensors
        return impl[fn_name](self, self._bufs[0], input, weight, *args)
    else:
        res = []
        for g in range(self.groups):
            def group(tensor, dim=None):
                # Slice out group g's channels; 1D tensors (e.g. bias) are
                # split along dim 0, everything else along dim 1 by default
                if tensor is None:
                    return None
                if dim is None:
                    dim = 0 if tensor.dim() == 1 else 1
                n = tensor.size(dim) // self.groups
                return tensor.narrow(dim, n * g, n).contiguous()

            # input is split along its channel dim (1), weight along its
            # output-channel dim (0)
            grouped_args = [group(input, 1), group(weight, 0)]
            grouped_args += [group(t) for t in args]
            res.append(impl[fn_name](self, self._bufs[g], *grouped_args))

        if fn_name == 'grad_params':
            # One result tuple per group: concatenate each parameter's
            # per-group gradients along dim 0
            return [torch.cat(t, 0) if t[0] is not None else None
                    for t in zip(*res)]
        else:
            # Outputs (and grad_input) are concatenated back along the channel dim
            return torch.cat(res, 1)
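To see what the inner group helper is doing, here is a minimal standalone sketch that slices a toy input and weight tensor into per-group chunks with Tensor.narrow, the same call used above. The tensor shapes and groups=2 are invented purely for illustration and are not taken from the original module:

import torch

groups = 2
x = torch.randn(1, 4, 8, 8)   # (N, C, H, W): 4 input channels
w = torch.randn(6, 2, 3, 3)   # 6 output channels, 4 // groups input channels each

for g in range(groups):
    n_in = x.size(1) // groups       # input channels per group
    n_out = w.size(0) // groups      # output channels per group
    x_g = x.narrow(1, n_in * g, n_in).contiguous()    # input split along dim 1
    w_g = w.narrow(0, n_out * g, n_out).contiguous()  # weight split along dim 0
    print(g, x_g.shape, w_g.shape)   # (1, 2, 8, 8) and (3, 2, 3, 3) for each group

Each group's convolution then runs on its own channel slice, and the per-group outputs are concatenated along the channel dimension, which is exactly what the torch.cat(res, 1) at the end of _thnn does.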