interpolate.py source code

python

Project: FaderNetworks    Author: facebookresearch
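import numpy as np
import torch
from torch.autograd import Variable


def get_interpolations(ae, images, attributes, params):
    """
    Reconstruct images and create attribute interpolations.

    Returns a CPU tensor that stacks, along dimension 1, each original image,
    its reconstruction, and one decoding per interpolation value.
    """
    assert len(images) == len(attributes)
    # encode once; the same latent codes are decoded with different attributes
    enc_outputs = ae.encode(images)

    # interpolation values: n_interpolations points from 1 - alpha_min to alpha_max,
    # each turned into a two-value attribute vector [1 - alpha, alpha]
    alphas = np.linspace(1 - params.alpha_min, params.alpha_max, params.n_interpolations)
    alphas = [torch.FloatTensor([1 - alpha, alpha]) for alpha in alphas]

    # original image / reconstructed image (true attributes) / interpolations;
    # [-1] takes the last entry of decode()'s output list, i.e. the decoded image
    outputs = []
    outputs.append(images)
    outputs.append(ae.decode(enc_outputs, attributes)[-1])
    for alpha in alphas:
        # repeat the attribute vector for every image in the batch and move it to the GPU
        alpha = Variable(alpha.unsqueeze(0).expand((len(images), 2)).cuda())
        outputs.append(ae.decode(enc_outputs, alpha)[-1])

    # add a new dimension 1 to each output, concatenate along it, return on the CPU
    return torch.cat([x.unsqueeze(1) for x in outputs], 1).data.cpu()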
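To make the interpolation schedule and the output layout concrete, here is a NumPy-only sketch of the same bookkeeping. It is not FaderNetworks code: the autoencoder is replaced by a hypothetical `fake_decode` stand-in that simply scales each image by its alpha, the helper names `interpolation_attributes` and `stack_interpolations` are illustrative, and the values alpha_min = 1, alpha_max = 1, n_interpolations = 5 are assumptions chosen for the demo. It reproduces only the `[1 - alpha, alpha]` attribute rows and the stacking along a new axis 1.

```python
import numpy as np


def interpolation_attributes(alpha_min, alpha_max, n_interpolations):
    """Return an (n_interpolations, 2) array of [1 - alpha, alpha] rows."""
    alphas = np.linspace(1 - alpha_min, alpha_max, n_interpolations)
    return np.stack([1 - alphas, alphas], axis=1)


def fake_decode(images, attr):
    """Hypothetical stand-in decoder: scales each image by its attribute's alpha."""
    # attr has shape (batch, 2); column 1 holds alpha
    return images * attr[:, 1].reshape(-1, 1, 1, 1)


def stack_interpolations(images, attributes, attr_rows):
    """Mimic get_interpolations' output layout without a real autoencoder."""
    outputs = [images, fake_decode(images, attributes)]
    for row in attr_rows:
        # repeat the same attribute vector for every image in the batch
        batch_attr = np.broadcast_to(row, (len(images), 2))
        outputs.append(fake_decode(images, batch_attr))
    # insert a new axis 1 and concatenate, like torch.cat([x.unsqueeze(1) ...], 1)
    return np.concatenate([x[:, None] for x in outputs], axis=1)


rows = interpolation_attributes(alpha_min=1.0, alpha_max=1.0, n_interpolations=5)
print(rows[:, 1].tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]

images = np.ones((3, 3, 8, 8))  # batch of 3 dummy 3x8x8 "images"
attributes = np.tile([0.0, 1.0], (3, 1))
grid = stack_interpolations(images, attributes, rows)
print(grid.shape)  # (3, 7, 3, 8, 8)
```

Each row of `grid` holds one image's sequence: original, reconstruction, then one output per alpha, so the second axis has 2 + n_interpolations entries. Because the sweep starts at 1 - alpha_min, setting alpha_min or alpha_max above 1 pushes alpha outside [0, 1] (for example, alpha_min = 2 starts the sweep at alpha = -1), which asks the decoder to extrapolate past the two endpoint attribute vectors.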