abstract_infer.py source code

python

Project: pyro  Author: uber
    def _dist_and_values(self, *args, **kwargs):
        # XXX currently this whole object is very inefficient
        values, logits = [], []
        for value, logit in self._gen_weighted_samples(*args, **kwargs):
            ix = _index(values, value)
            if ix == -1:
                # Value is new.
                values.append(value)
                logits.append(logit)
            else:
                # Value has already been seen.
                logits[ix] = util.log_sum_exp(torch.stack([logits[ix], logit]).squeeze())

        # Stack the per-value log-weights into one tensor, then normalize them
        # in log space so that they sum to one after exponentiation.
        logits = torch.stack(logits).squeeze()
        if not isinstance(logits, torch.autograd.Variable):
            logits = Variable(logits)
        logits = logits - util.log_sum_exp(logits)

        d = dist.Categorical(logits=logits, one_hot=False)
        return d, values
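
The merge-and-normalize step in the method above can be illustrated without torch or pyro. The following is only a minimal sketch under stated assumptions: `log_sum_exp` and `merge_weighted_samples` are hypothetical stand-ins written for this example (not pyro's API), `list.index` replaces the `_index` helper (whose exact equality semantics are not shown in the snippet), and the input is assumed to be non-empty.

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs)) for a non-empty list."""
    m = max(xs)
    if m == float("-inf"):
        # Every weight is zero; avoid computing -inf - (-inf).
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


def merge_weighted_samples(samples):
    """Collapse repeated values and normalize their log-weights.

    `samples` is an iterable of (value, log_weight) pairs. Hypothetical
    helper for illustration only; it mirrors the loop in the snippet above.
    """
    values, logits = [], []
    for value, logit in samples:
        try:
            ix = values.index(value)  # stand-in for the snippet's _index()
        except ValueError:
            # Value is new.
            values.append(value)
            logits.append(logit)
        else:
            # Value has already been seen: add the weights in log space.
            logits[ix] = log_sum_exp([logits[ix], logit])
    # Subtract the log normalizer so the probabilities sum to one.
    log_z = log_sum_exp(logits)
    return values, [logit - log_z for logit in logits]


# "a" appears twice with weight 1 each, "b" once with weight 2.
samples = [("a", math.log(1.0)), ("b", math.log(2.0)), ("a", math.log(1.0))]
values, log_probs = merge_weighted_samples(samples)
probs = [math.exp(lp) for lp in log_probs]
```

Here the two occurrences of `"a"` are merged into a single entry whose weight equals that of `"b"`, so both values end up with equal probability. The linear scan per sample makes the loop quadratic in the number of distinct values, which is consistent with the "very inefficient" remark in the original comment.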