def decode(self, target):
    """Rebuild a signal from this genome's sparse DCT coefficients.

    ``target`` is used as a scratch buffer. It is zeroed in place, then each
    gene's ``value`` is written at its flat (row-major, via ``ravel``)
    ``index``. Finally an orthonormal inverse DCT is applied:

    * 1-D targets get a 1-D inverse DCT.
    * N-D targets (N >= 2) are viewed as a 2-D array of shape
      ``(prod(shape[:-1]), shape[-1])`` and get a separable 2-D inverse DCT
      over those two axes. For N == 2 this is the ordinary 2-D transform;
      for N >= 3 the leading axes are flattened into one first.

    Args:
        target: array whose shape defines the output shape. Its contents are
            overwritten: it is zeroed, and the gene coefficients land in it
            too whenever ``ravel`` can return a view rather than a copy.

    Returns:
        A new array with the same shape as ``target`` holding the decoded
        signal.

    Raises:
        ValueError: if ``target`` is 0-dimensional (there is no axis to
            transform).
    """
    original_shape = target.shape
    if not original_shape:
        # Reject 0-d input before mutating the caller's buffer: no branch
        # below can transform it.
        raise ValueError("decode() requires a target with at least one dimension")

    target.fill(0.)
    flat = target.ravel()
    for gene in self.genes:
        flat[gene.index] = gene.value

    kwargs = dict(norm='ortho')
    if len(original_shape) == 1:
        return idct(flat, **kwargs)

    # Collapse every leading axis into one so a single 2-D transform covers
    # both N == 2 (where this reshape is a no-op) and N >= 3.
    rows = int(np.prod(original_shape[:-1]))
    coeffs = flat.reshape(rows, original_shape[-1])
    # idct works along the last axis: transpose to transform the row axis,
    # transpose back, then transform the column axis.
    out = idct(idct(coeffs.T, **kwargs).T, **kwargs)
    return out.reshape(original_shape)
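# NOTE(review): web-page navigation residue ("Comment list",
# "Table of contents") from the page this code was copied from.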