def reindex_target(self, target, e):
    """Re-express target indices in the order induced by sorting ``e``.

    Each row of ``e`` is argsorted along dim 1, giving ``ind`` where
    ``ind[new_pos] = old_pos``.  Argsorting ``ind`` again yields the inverse
    permutation ``ind_inv`` with ``ind_inv[old_pos] = new_pos``, which is used
    to remap every valid entry of ``target``.

    Negative entries of ``target`` are treated as padding:
      * they are excluded from the remapping,
      * the remapped valid entries are compacted to the front of the row,
      * positions that were padding in the input are 0 in the output.

    Args:
        target: numpy array of integer-valued indices, one row per example;
            negative values mark padding.  It is not modified in place.
        e: torch tensor/Variable whose rows are sorted along dim 1.
            Presumably shaped (batch, n) or (batch, n, 1) -- TODO confirm
            against callers.

    Returns:
        A new float numpy array with the same shape as ``target``.
    """
    # Sort indices along dim 1.  Flatten trailing dims but always keep the
    # batch dim: a bare .squeeze() would also drop it when batch == 1 and
    # make the second sort below fail.
    ind = torch.sort(e, 1)[1]
    ind = ind.contiguous().view(ind.size(0), -1)
    # Invert the permutation: ind_inv[p] is where original element p lands
    # after sorting, so new_target = ind_inv[target].
    ind_inv = torch.sort(ind, 1)[1]
    # A single device->host transfer for the whole batch instead of one per row.
    ind_inv_np = ind_inv.data.cpu().numpy()

    valid = target >= 0
    mask = valid.astype(float)
    # Zero out padding so those slots are usable (harmless) indices below.
    # This also creates a new array, leaving the caller's target untouched.
    target = target * mask
    # Iterate over the rows actually present (e.g. a short final batch),
    # not a fixed configured batch size.
    for example in range(target.shape[0]):
        tar = ind_inv_np[example][target[example].astype(int)]
        # Compact the remapped valid entries to the front of the row.
        tar_aux = tar[valid[example]]
        tar[:tar_aux.shape[0]] = tar_aux
        target[example] = tar
    # Positions that were padding in the input mask end up as 0.
    return target * mask
评论列表
文章目录