reshape.py source code

Project: inferno  Author: inferno-pytorch
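```python
import torch
import torch.nn.functional as F

# Import paths assumed from the inferno package layout.
from inferno.utils.exceptions import assert_, ShapeError
from inferno.utils import python_utils as pyu


# Excerpt: forward() of a torch.nn.Module subclass. The enclosing class
# (not shown) must define self.pool_mode ('avg' or 'max') and self.target_size.
def forward(self, *inputs):
    dim = inputs[0].dim()
    assert_(dim in [4, 5],
            'Input tensors must either be 4 or 5 '
            'dimensional, but inputs[0] is {}D.'.format(dim),
            ShapeError)
    # Pick the adaptive pooling function for the number of spatial dims:
    # 4D (N, C, H, W) -> 2 spatial dims, 5D (N, C, D, H, W) -> 3 spatial dims
    spatial_dim = {4: 2, 5: 3}[dim]
    resize_function = getattr(F, 'adaptive_{}_pool{}d'.format(self.pool_mode,
                                                              spatial_dim))
    target_size = pyu.as_tuple_of_len(self.target_size, spatial_dim)
    # Resize every input to the same spatial size
    resized_inputs = []
    for input_num, input in enumerate(inputs):
        # Every input must have the same number of dimensions as inputs[0]
        assert_(input.dim() == dim,
                "Expected inputs[{}] to be a {}D tensor, got a {}D "
                "tensor instead.".format(input_num, dim, input.dim()),
                ShapeError)
        resized_inputs.append(resize_function(input, target_size))
    # Concatenate along the channel axis (dim 1)
    concatenated = torch.cat(tuple(resized_inputs), 1)
    return concatenated
```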
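To make the resize-then-concatenate step concrete without depending on PyTorch, here is a minimal pure-Python sketch of what the forward pass does for one sample. It assumes each input is a list of channels, each channel an H x W nested list, and it only covers the `'avg'` pooling mode. The bin rule (start = floor(i * in / out), end = ceil((i + 1) * in / out)) is, to our understanding, the one PyTorch's adaptive pooling kernels use; this is an illustration, not inferno's or PyTorch's code, and the helper names `pool_function_name`, `adaptive_bins`, `avg_pool_to_size` and `resize_and_concatenate` are hypothetical.

```python
from typing import List, Sequence, Tuple

Grid = List[List[float]]  # one channel: H rows of W values


def pool_function_name(ndim: int, pool_mode: str) -> str:
    """Mirror the getattr lookup: 4D input -> *_pool2d, 5D input -> *_pool3d."""
    if ndim not in (4, 5):
        raise ValueError(f"expected a 4D or 5D tensor, got {ndim}D")
    spatial_dim = {4: 2, 5: 3}[ndim]
    return f"adaptive_{pool_mode}_pool{spatial_dim}d"


def adaptive_bins(in_size: int, out_size: int) -> List[Tuple[int, int]]:
    """Half-open [start, end) input range for each output index.

    start = floor(i * in / out), end = ceil((i + 1) * in / out), computed with
    integer arithmetic. Neighbouring bins may overlap when in_size is not a
    multiple of out_size, and every bin is non-empty even when upsampling.
    """
    return [((i * in_size) // out_size,
             ((i + 1) * in_size + out_size - 1) // out_size)
            for i in range(out_size)]


def avg_pool_to_size(grid: Grid, target: Tuple[int, int]) -> Grid:
    """Adaptive average pooling of one H x W channel to target = (out_h, out_w)."""
    out_h, out_w = target
    row_bins = adaptive_bins(len(grid), out_h)
    col_bins = adaptive_bins(len(grid[0]), out_w)
    pooled = []
    for r0, r1 in row_bins:
        row = []
        for c0, c1 in col_bins:
            window = [grid[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(window) / len(window))
        pooled.append(row)
    return pooled


def resize_and_concatenate(inputs: Sequence[List[Grid]],
                           target: Tuple[int, int]) -> List[Grid]:
    """Pool every channel of every input to `target`, then stack the channels.

    Each input is a list of channels for a single sample (no batch axis), so
    concatenating along the channel axis is plain list concatenation.
    """
    out: List[Grid] = []
    for channels in inputs:
        out.extend(avg_pool_to_size(ch, target) for ch in channels)
    return out


a = [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]  # 1 x 4 x 4
b = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]                             # 2 x 2 x 2
out = resize_and_concatenate([a, b], (2, 2))
print(len(out), out[0])  # 3 [[2.5, 4.5], [10.5, 12.5]]
```

Because every input is pooled to the same spatial size, channel counts simply add up: a 1-channel 4x4 input and a 2-channel 2x2 input give 3 channels of 2x2. This is the same shape bookkeeping `torch.cat(..., 1)` performs on (N, C, H, W) tensors, and it is why the method can accept inputs with different spatial sizes but not different ranks.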