def _find_support(B, ns, supp_thresh):
    """Find features with non-zero coefficients.

    ``B`` is iterated along axis 0. Row ``j`` holds a group of coefficients,
    of which only the first ``ns[j]`` entries along axis 1 are kept; the rest
    are presumably padding (TODO confirm against callers). A group is "in the
    support" when the L2 norm of its row (taken over axis 1) is at least
    ``supp_thresh``. Every kept position of a supported group counts as
    selected.

    Args:
        B: Tensor of coefficients with at least 2 dimensions; axis 0 indexes
            groups and axis 1 indexes coefficients within a group.
        ns: Iterable of per-group sizes; ``ns[j]`` entries of row ``j`` are
            kept.
        supp_thresh: Norm threshold; a group is selected when its norm is
            ``>= supp_thresh``.

    Returns:
        A 1-D tensor of indices into the concatenation of the kept entries
        (i.e. flat feature indices), or ``None`` when no feature is selected
        or there are no groups.
    """
    # keepdim=True keeps the reduced axis, so the per-group mask broadcasts
    # along axis 1 of B. Without it (the default since PyTorch 0.2) the mask
    # has shape (J, ...). expand_as then either fails or, when J happens to
    # equal B.shape[1], silently broadcasts along the wrong axis.
    group_norms = B.norm(p=2, dim=1, keepdim=True)
    mask = (group_norms >= supp_thresh).expand_as(B)

    # Keep only the first n_j entries of each group (drop any padding).
    pieces = [s_j[:n_j] for s_j, n_j in zip(mask, ns)]
    if not pieces:
        # torch.cat([]) raises; no groups means no support.
        return None

    indices = torch.nonzero(torch.cat(pieces))[:, 0]
    # Old PyTorch raised IndexError on an empty nonzero() result, and that
    # was mapped to None. Modern PyTorch returns an empty tensor instead,
    # so check explicitly to keep the "None when empty" contract.
    if indices.numel() == 0:
        return None
    return indices
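# NOTE: lines "评论列表" ("comment list") and "文章目录" ("table of contents")
# were web-page navigation residue copied along with this snippet; not code.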