def _dist_and_values(self, *args, **kwargs):
    """Collapse weighted samples into a categorical distribution over distinct values.

    Draws ``(value, logit)`` pairs from ``self._gen_weighted_samples(*args, **kwargs)``.
    Pairs whose values match (as judged by ``_index``) are merged by combining their
    logits with ``util.log_sum_exp``. The merged logits are then normalized and used
    to build a ``dist.Categorical``.

    Args:
        *args, **kwargs: Forwarded unchanged to ``self._gen_weighted_samples``.

    Returns:
        tuple: ``(d, values)``.
            ``d`` is a ``dist.Categorical`` built with ``one_hot=False``.
            ``values`` is the list of distinct values in first-seen order.
            Category ``i`` of ``d`` corresponds to ``values[i]``.
    """
    # XXX currently this whole object is very inefficient: every sample is
    # looked up in ``values`` via ``_index`` (presumably a linear scan, since
    # values may be unhashable, e.g. tensors), so deduplication is quadratic.
    values, logits = [], []
    for value, logit in self._gen_weighted_samples(*args, **kwargs):
        ix = _index(values, value)
        if ix == -1:
            # Value is new.
            values.append(value)
            logits.append(logit)
        else:
            # Value has already been seen: accumulate its weight in log space,
            # presumably log(exp(a) + exp(b)) via util.log_sum_exp.
            logits[ix] = util.log_sum_exp(torch.stack([logits[ix], logit]).squeeze())
    # NOTE(review): torch.stack raises on an empty list, so a generator that
    # yields nothing fails here. With a single distinct value, squeeze() may
    # collapse to a 0-d tensor on newer PyTorch. Confirm that Categorical
    # accepts that shape.
    logits = torch.stack(logits).squeeze()
    if not isinstance(logits, torch.autograd.Variable):
        logits = Variable(logits)
    # Normalize exactly once, out of place, so nothing already recorded in an
    # autograd graph is mutated. Assuming util.log_sum_exp reduces over all
    # entries, the resulting logits sum to one in probability space.
    logits = logits - util.log_sum_exp(logits)
    d = dist.Categorical(logits=logits, one_hot=False)
    return d, values
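# NOTE(review): two stray web-page navigation labels stood here
# ("评论列表" = "Comment list", "文章目录" = "Table of contents").
# They are scraping residue, not part of this module, and can be removed.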