basic.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:PyTorch-Encoding 作者: zhanghang1989 项目源码 文件源码
def upsample(input, size=None, scale_factor=None, mode='nearest'):
    """Multi-GPU version of torch.nn.functional.upsample.

    Upsamples the input to either the given :attr:`size` or the given
    :attr:`scale_factor`.

    The algorithm used for upsampling is determined by :attr:`mode`.

    Currently temporal, spatial and volumetric upsampling are supported, i.e.
    expected inputs are 3-D, 4-D or 5-D in shape.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [depth] x [height] x width`

    The modes available for upsampling are: `nearest`, `linear` (3D-only),
    `bilinear` (4D-only), `trilinear` (5D-only)

    Args:
        input (Variable or list/tuple of Variable): a single input, or a
            sequence of inputs (typically one per GPU). Each element of a
            sequence is upsampled in its own thread inside
            ``torch.cuda.device_of(element)``.
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (int): multiplier for spatial size. Has to be an integer.
        mode (string): algorithm used for upsampling:
            'nearest' | 'linear' | 'bilinear' | 'trilinear'. Default: 'nearest'

    Returns:
        The upsampled Variable for a single input; for a list/tuple input,
        a list of upsampled results in the same order as ``input``.

    Raises:
        RuntimeError: if ``input`` is neither a Variable nor a list/tuple.
        Any exception raised by ``F.upsample`` for an element of a sequence
        is re-raised in the calling thread (the lowest-index failure first).
    """
    if isinstance(input, Variable):
        return F.upsample(input, size=size, scale_factor=scale_factor,
                          mode=mode)
    elif isinstance(input, (tuple, list)):
        lock = threading.Lock()
        results = {}

        def _worker(i, x):
            # Upsample one element on its own device and record the outcome
            # (result or exception) under its index.
            try:
                with torch.cuda.device_of(x):
                    result = F.upsample(x, size=size,
                                        scale_factor=scale_factor, mode=mode)
                with lock:
                    results[i] = result
            except Exception as e:
                # Exceptions in a thread do not propagate to the caller, so
                # store it and re-raise after join().
                with lock:
                    results[i] = e

        # multi-threading for different gpu
        threads = [threading.Thread(target=_worker, args=(i, x))
                   for i, x in enumerate(input)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Every worker stores exactly one entry, so keys are 0..len(input)-1.
        outputs = [results[i] for i in range(len(input))]
        for output in outputs:
            if isinstance(output, Exception):
                raise output
        return outputs
    else:
        raise RuntimeError('unknown input type')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号