def get_interpolations(ae, images, attributes, params):
    """
    Reconstruct images / create interpolations.

    For every image in the batch, collect the original image, its
    reconstruction decoded with its own attributes, and one decoding per
    interpolation value.

    Args:
        ae: autoencoder exposing ``encode(images)`` and
            ``decode(enc_outputs, attributes)``. The result of ``decode`` is
            indexed with ``[-1]``, so it presumably returns a sequence whose
            last element is the decoded image batch (TODO confirm).
        images: batch of images (tensor / Variable), one row per image.
        attributes: attributes of ``images``; must have the same length.
        params: object providing ``alpha_min``, ``alpha_max`` and
            ``n_interpolations``. Interpolation values are spaced linearly
            from ``1 - alpha_min`` to ``alpha_max`` (both inclusive), so the
            defaults ``alpha_min = alpha_max = 1`` give the range [0, 1].
            Each value ``a`` is decoded with the attribute pair ``[1 - a, a]``.

    Returns:
        CPU tensor stacking, along dim 1: the original images, the
        reconstructions, then the ``n_interpolations`` decodings, i.e. shape
        ``(len(images), 2 + n_interpolations, ...)`` (assuming decoded images
        have the same shape as ``images``).
    """
    assert len(images) == len(attributes), \
        "images and attributes must have the same length"
    batch_size = len(images)
    enc_outputs = ae.encode(images)

    # interpolation values, from (1 - alpha_min) up to alpha_max
    alphas = np.linspace(1 - params.alpha_min, params.alpha_max,
                         params.n_interpolations)

    # Put the interpolation attributes on the same device as the input batch
    # instead of always calling .cuda(): `images` is concatenated with the
    # decoder outputs below, so it already lives where the model runs.
    use_cuda = images.data.is_cuda

    # original image / reconstructed image / interpolations
    outputs = [images, ae.decode(enc_outputs, attributes)[-1]]
    for alpha in alphas:
        # the same [1 - alpha, alpha] attribute pair for every image
        alpha_attr = torch.FloatTensor([1 - alpha, alpha])
        alpha_attr = alpha_attr.unsqueeze(0).expand((batch_size, 2))
        if use_cuda:
            alpha_attr = alpha_attr.cuda()
        outputs.append(ae.decode(enc_outputs, Variable(alpha_attr))[-1])

    # return stacked images: new dim 1 indexes original / recon / interps
    return torch.cat([x.unsqueeze(1) for x in outputs], 1).data.cpu()
评论列表
文章目录