def draw(self, N):
    """Draw ``N`` samples using the alias-method tables ``self.prob`` / ``self.alias``.

    For each sample, a column ``k`` is chosen uniformly from ``[0, K)``
    (``K = self.alias.size(0)``). The sample is ``k`` itself with probability
    ``self.prob[k]``, and ``self.alias[k]`` otherwise.

    Args:
        N: number of samples to draw (non-negative int).

    Returns:
        1-D tensor of ``N`` sampled indices. It is created on ``self.prob``'s
        device. NOTE(review): assumes ``self.prob`` and ``self.alias`` live on
        the same device and that ``self.alias`` is an integer (int64) tensor.
        Confirm against the table-building code.
    """
    K = self.alias.size(0)
    # Draw the columns with torch on the tables' device rather than NumPy on
    # the CPU. A CPU index cannot index CUDA tables, and using a single RNG
    # means torch.manual_seed alone makes the draw reproducible.
    kk = torch.randint(0, K, (N,), device=self.prob.device)
    prob = self.prob.index_select(0, kk)
    alias = self.alias.index_select(0, kk)
    # bernoulli(q) is 1 with probability q, i.e. when a uniform draw falls
    # *below* q. In that case the column kk itself is the sample; otherwise
    # its alias is.
    keep = torch.bernoulli(prob).bool()
    return torch.where(keep, kk, alias)
评论列表
文章目录