def sort_by_embeddings(self, Phis, Inputs_N, e):
    """Reorder each sample in the batch by ascending value of ``e``.

    A permutation ``ind[b]`` is computed for every sample ``b`` by sorting
    ``e`` along dimension 1 (ascending). Each ``Phis[i]`` then has both its
    dim-1 and dim-2 entries permuted by ``ind``. Each ``Inputs_N[i]`` has
    its dim-1 entries permuted by ``ind``.

    Args:
        Phis: list of tensors. The index expansion below requires each one
            to be shaped (batch, N, N).
        Inputs_N: list holding at least ``len(Phis)`` tensors, each shaped
            (batch, N, D).
        e: tensor of shape (batch, N) or (batch, N, 1) holding the sort keys.

    Returns:
        ``(Phis, Inputs_N)``. These are the same list objects that were
        passed in; their first ``len(Phis)`` entries are replaced in place.
    """
    batch_size, n = e.size(0), e.size(1)
    # Sort indices along dim 1. Reshape explicitly instead of calling a bare
    # squeeze(): squeeze() would also drop the batch axis when batch_size == 1
    # (or the N axis when n == 1), and the gathers below would then fail.
    ind = torch.sort(e, 1)[1].contiguous().view(batch_size, n)
    row_index = ind.unsqueeze(2)  # (batch, N, 1): broadcast across dim 2
    col_index = ind.unsqueeze(1)  # (batch, 1, N): broadcast across dim 1
    for i, phis in enumerate(Phis):
        # Rearrange phis: permute dim 1, then dim 2, with the same ordering.
        phis_out = torch.gather(phis, 1, row_index.expand_as(phis))
        Phis[i] = torch.gather(phis_out, 2, col_index.expand_as(phis))
        # Rearrange inputs: permute dim 1 only.
        Inputs_N[i] = torch.gather(
            Inputs_N[i], 1, row_index.expand_as(Inputs_N[i]))
    return Phis, Inputs_N
评论列表
文章目录