def set_model_permutations(self) -> None:
    """Build random per-level permutations and their inverses.

    For each ``n`` in ``range(1, self.N)``:

    * a permutation of ``range(2 ** (n - 1))`` is appended to
      ``self.model_permutations``, and
    * its inverse is appended to ``self.model_unpermutations``, so that
      ``unperm[perm[i]] == i`` for every index ``i``.

    For ``n == 1`` the only possible permutation, ``[0]``, is used. For
    ``n > 1`` a permutation is drawn with ``torch.randperm``. It is redrawn
    until it differs from the identity, so every such level is actually
    shuffled.

    Both attributes are replaced with fresh lists on every call. Randomness
    comes from torch's global RNG.
    """
    # NOTE(review): assumes self.N is an int; N <= 1 yields two empty lists.
    self.model_permutations = []
    self.model_unpermutations = []
    for n in range(1, self.N):
        size = 2 ** (n - 1)
        identity = list(range(size))
        permutation = identity
        if n > 1:
            # Rejection-sample until the permutation is not the identity.
            # ``Tensor.tolist()`` works for tensors on any device (e.g. when
            # ``torch.set_default_device('cuda')`` is active), whereas
            # ``.numpy()`` only accepts CPU tensors.
            while permutation == identity:
                permutation = torch.randperm(size).tolist()
        self.model_permutations.append(permutation)

        # Inverse permutation: unpermutation[permutation[i]] == i.
        unpermutation = [0] * size
        for index, target in enumerate(permutation):
            unpermutation[target] = index
        self.model_unpermutations.append(unpermutation)
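# 评论列表 ("comment list") -- web-page navigation residue, not code
# 文章目录 ("table of contents") -- web-page navigation residue, not code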