vae.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def argmax(self, z, max_len):
        # local variables
        eos, bos = self.src_dict.get_eos(), self.src_dict.get_bos()
        batch = z.size(0)
        # output variables
        scores, preds, mask = 0, [], z.data.new(batch).long() + 1
        # model inputs
        hidden = self.decoder.init_hidden_for(z)
        prev = Variable(z.data.new(batch).zero_().long() + bos, volatile=True)

        for _ in range(max_len):
            prev_emb = self.embeddings(prev).squeeze(0)
            dec_out, hidden = self.decoder(prev_emb, hidden, z=z)
            dec_out = self.project(dec_out.unsqueeze(0))

            score, pred = dec_out.max(1)
            scores += score.squeeze().data
            preds.append(pred.squeeze().data)
            prev = pred

            mask = mask * (pred.squeeze().data[0] != eos)
            if mask.int().sum() == 0:
                break

        return scores.tolist(), torch.stack(preds).transpose(0, 1).tolist()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号