def setup_reparam_mask(self, n):
    """Sample a random 0/1 mask over ``n`` positions by rejection sampling.

    Each position is drawn independently as Bernoulli(0.30). The draw is
    repeated until the number of selected positions is at least one and
    strictly less than ``0.40 * n``.

    Args:
        n: Number of positions in the mask. It is passed to ``torch.ones``
            and scaled by ``0.40``, so it is expected to be an ``int``.
            ``self`` is not used by this method.

    Returns:
        A tensor of shape ``(n,)`` whose entries are 0. or 1., containing
        at least one 1 and fewer than ``0.40 * n`` ones.

    Raises:
        ValueError: If no mask can satisfy the acceptance condition
            (``0.40 * n <= 1``, i.e. ``n <= 2`` for integer ``n``). Without
            this check the sampling loop would never terminate.
    """
    # Acceptance needs an integer count k with 0.5 < k < 0.40 * n. The
    # smallest candidate is k = 1, so a valid mask exists iff 0.40 * n > 1.
    if 0.40 * n <= 1:
        raise ValueError(
            "setup_reparam_mask requires 0.40 * n > 1 (n >= 3) so that a mask "
            f"with at least one and fewer than 40% ones exists; got n={n}"
        )

    # Loop-invariant success probabilities; torch.bernoulli does not modify them.
    probs = 0.30 * torch.ones(n)
    while True:
        mask = torch.bernoulli(probs)
        count = torch.sum(mask)
        # count > 0.5 means at least one position was selected (sum of 0/1 values).
        if 0.5 < count < 0.40 * n:
            return mask
# For sampling from the model in different sequential orders.