misc.py 文件源码

python
阅读 72 收藏 0 点赞 0 评论 0

项目:pytorch-semantic-segmentation 作者: ZijunDeng 项目源码 文件源码
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt
    return torch.from_numpy(weight).float()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号