def eliminate_rows(self, prob_sc, ind, phis, threshold=0.85):
    """Move finished rows to the end and mask them out of ``phis``.

    A row ``i`` of a batch element counts as finished ("masked") when
    ``prob_sc[b, i, 0] > threshold``.  Active rows keep their relative
    order and are moved to the front; masked rows are moved to the back.

    Args:
        prob_sc: tensor indexed as ``prob_sc[:, :, 0]``; dimension 1 is the
            row axis (its size is called ``length`` below).
        ind: integer index tensor gathered along dimension 1 with a
            ``(batch, length)`` index; it is also used as a gather index
            into ``phis``, so it presumably holds the current row
            permutation -- confirm against the caller.
        phis: tensor that both ``(batch, length, 1)`` and
            ``(batch, 1, length)`` are expanded to, i.e.
            ``(batch, length, length)``.
        threshold: value above which channel 0 of ``prob_sc`` marks a row
            as finished.  Defaults to the previously hard-coded 0.85.

    Returns:
        Tuple ``(prob_sc, ind, phis_out, active)``:
        * ``prob_sc`` -- permuted input; masked rows are replaced by
          ``[1, 0, ..., 0]`` along the last dimension.
        * ``ind`` -- the input ``ind`` composed with the new permutation.
        * ``phis_out`` -- ``phis`` with rows and columns gathered by the
          new ``ind`` and zeroed wherever the row or column is masked.
        * ``active`` -- ``(batch, length)`` tensor, 1 for active rows and
          0 for masked rows, in the permuted order.

    Note:
        Intermediate tensors take their type from ``prob_sc`` rather than
        from a module-level ``dtype`` global, so the function follows the
        device/precision of its inputs.
    """
    length = prob_sc.size(1)
    # 1.0 where the row is finished, 0.0 where it is still active.
    mask = (prob_sc[:, :, 0] > threshold).type_as(prob_sc)
    # Row positions 0..length-1, one copy per batch element.  torch.arange
    # replaces the deprecated, endpoint-inclusive torch.range.
    positions = Variable(
        torch.arange(0, length).unsqueeze(0).expand_as(mask)
    ).type_as(mask)
    # Sort key: active rows keep their position, masked rows all get the
    # key `length`, which is larger than any position, so they sort last.
    # Masked rows tie with each other; their mutual order is whatever
    # torch.sort yields for equal keys.
    sort_key = positions * (1 - mask) + length * mask
    ind_sc = torch.sort(sort_key, 1)[1]

    # Permute prob_sc: masked rows are overwritten with [1, 0, ..., 0].
    m = mask.unsqueeze(2).expand_as(prob_sc)
    mm = m.clone()  # clone so the in-place write does not hit the expanded view
    mm[:, :, 1:] = 0
    prob_sc = torch.gather(
        prob_sc * (1 - m) + mm, 1, ind_sc.unsqueeze(2).expand_as(prob_sc)
    )

    # Compose the incoming permutation with the new one.
    ind = torch.gather(ind, 1, ind_sc)
    active = torch.gather(1 - mask, 1, ind_sc)

    # Permute phis along rows, then columns, zeroing masked entries.
    active1 = active.unsqueeze(2).expand_as(phis)
    ind1 = ind.unsqueeze(2).expand_as(phis)
    active2 = active.unsqueeze(1).expand_as(phis)
    ind2 = ind.unsqueeze(1).expand_as(phis)
    phis_out = torch.gather(phis, 1, ind1) * active1
    phis_out = torch.gather(phis_out, 2, ind2) * active2
    return prob_sc, ind, phis_out, active
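# [web-page residue, not part of the code] Comment list
# [web-page residue, not part of the code] Table of contents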