def sample(self):
    """
    Draw a sample from the Categorical distribution.

    An index tensor is drawn from ``self.ps.data`` and expanded to
    ``self.shape()``; a one-hot encoding with the same shape as ``ps`` is
    then built from it. What is returned depends on the configuration:

    * ``vs`` is a ``numpy.ndarray``: the entries of ``vs`` selected by the
      one-hot mask, reshaped to ``self.shape()``.
    * ``vs`` is set but is not a numpy array: the result of
      ``vs.masked_select(...)`` with the one-hot mask, returned unchanged.
    * ``vs`` is ``None`` and ``one_hot`` is truthy: the one-hot encoding
      wrapped in a ``Variable``.
    * otherwise: the drawn indices wrapped in a ``Variable``.

    NOTE(review): list-valued ``vs`` is presumably converted to a numpy
    array before this method runs -- that conversion is not visible here.

    :return: sample from the Categorical distribution
    :rtype: numpy.ndarray when `vs` is a numpy array; otherwise a
        tensor-like object (a ``Variable`` when `vs` is ``None``)
    """
    # One draw (second argument is 1) with replacement=True; presumably a
    # thin wrapper around torch.multinomial -- confirm against the helper.
    drawn = torch_multinomial(self.ps.data, 1, replacement=True)
    indices = drawn.expand(*self.shape())

    # Put a 1 at each drawn index along the last dimension of a zero
    # tensor shaped like `ps`.
    one_hot = torch_zeros_like(self.ps.data).scatter_(-1, indices, 1)

    values = self.vs
    if values is None:
        # No value table: hand back either the one-hot encoding or the
        # raw indices, depending on the `one_hot` flag.
        if self.one_hot:
            return Variable(one_hot)
        return Variable(indices)

    if isinstance(values, np.ndarray):
        # numpy path: use the one-hot encoding as a boolean mask.
        mask = one_hot.cpu().numpy().astype(bool)
        return values[mask].reshape(*self.shape())

    # Tensor path: select with a byte mask built from the one-hot encoding.
    return values.masked_select(one_hot.byte())
评论列表
文章目录