train.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pytorch-skipthoughts 作者: kaniblu 项目源码 文件源码
def val_sents(self, data, dec_logits):
        """Render the first ``self.previews`` examples of a batch as text.

        The per-row tensors in ``data`` and ``dec_logits`` are re-ordered
        with the permutation obtained by sorting ``xys_idx``, truncated to
        ``self.previews`` examples, and every token id is mapped to a string
        through ``self.model.vocab.i2f``. For ``dec_logits`` the id used is
        the index of the largest score along dim 2.

        Args:
            data: 6-tuple ``(x, x_lens, ys_i, ys_t, ys_lens, xys_idx)``.
                ``x`` and ``x_lens`` are indexed by example along dim 0.
                ``ys_i`` and ``ys_t`` are 3-D and ``ys_lens`` / ``xys_idx``
                are 2-D; their example axis is dim 1, or dim 0 when
                ``self.batch_first`` is true (they are transposed first).
                NOTE(review): presumably ``ys_i`` / ``ys_t`` are decoder
                inputs / targets and the other leading axis enumerates
                decoders -- confirm against the batching code.
            dec_logits: iterable of 3-D score tensors, one per entry of the
                leading axis of ``ys_i`` (a single 4-D tensor also works and
                is required when ``self.batch_first`` is true, because it is
                transposed together with the other tensors).

        Returns:
            Tuple ``(x_sents, yi_sents, yt_sents, o_sents)``. ``x_sents`` is
            a list with one string per previewed example; the other three
            are lists (per example) of lists (per leading-axis entry) of
            strings built from ``ys_i``, ``ys_t`` and the arg-max of
            ``dec_logits`` respectively, each cut to the matching length in
            ``ys_lens``.
        """
        vocab, previews = self.model.vocab, self.previews
        x, x_lens, ys_i, ys_t, ys_lens, xys_idx = data

        if self.batch_first:
            # Swap the first two dims so the code below can always treat
            # dim 1 as the example axis.
            cdata = [ys_i, ys_t, ys_lens, xys_idx, dec_logits]
            cdata = [d.transpose(1, 0).contiguous() for d in cdata]
            ys_i, ys_t, ys_lens, xys_idx, dec_logits = cdata

        # The sort indices are the permutation that puts each row of xys_idx
        # in ascending order; applying it to the other tensors re-orders
        # their rows accordingly (presumably back to the order of ``x``).
        _, xys_ridx = torch.sort(xys_idx, 1)
        xys_ridx_exp = xys_ridx.unsqueeze(-1).expand_as(ys_i)
        ys_i = torch.gather(ys_i, 1, xys_ridx_exp)
        ys_t = torch.gather(ys_t, 1, xys_ridx_exp)
        dec_logits = [torch.index_select(logits, 0, xy_ridx)
                      for logits, xy_ridx in zip(dec_logits, xys_ridx)]
        ys_lens = torch.gather(ys_lens, 1, xys_ridx)

        def top_tokens(logits):
            # Index of the best score along dim 2 for the previewed rows.
            best = logits[:previews].max(2)[1]
            # Some PyTorch versions keep the reduced dim (size 1), others
            # drop it. Squeeze only when it is really there: an
            # unconditional squeeze(-1) would also remove a genuine step
            # axis of length 1 and break the string conversion below.
            if best.dim() == 3:
                best = best.squeeze(-1)
            return best.unsqueeze(0)

        x, x_lens = x[:previews], x_lens[:previews]
        ys_i, ys_t = ys_i[:, :previews], ys_t[:, :previews]
        dec_logits = torch.cat([top_tokens(logits) for logits in dec_logits],
                               0)
        ys_lens = ys_lens[:, :previews]

        # Put the example axis first so the nested lists group by example.
        ys_i, ys_t = ys_i.transpose(1, 0), ys_t.transpose(1, 0)
        dec_logits, ys_lens = dec_logits.transpose(1, 0), ys_lens.transpose(1,
                                                                            0)

        x, x_lens = x.data.tolist(), x_lens.data.tolist()
        ys_i, ys_t = ys_i.data.tolist(), ys_t.data.tolist()
        dec_logits, ys_lens = dec_logits.data.tolist(), ys_lens.data.tolist()

        def to_sent(data, length, vocab):
            # Only the first ``length`` ids are rendered; the rest is ignored.
            return " ".join(vocab.i2f[data[i]] for i in range(length))

        def to_sents(data, lens, vocab):
            return [to_sent(d, l, vocab) for d, l in zip(data, lens)]

        x_sents = to_sents(x, x_lens, vocab)
        yi_sents = [to_sents(yi, y_lens, vocab) for yi, y_lens in
                    zip(ys_i, ys_lens)]
        yt_sents = [to_sents(yt, y_lens, vocab) for yt, y_lens in
                    zip(ys_t, ys_lens)]
        o_sents = [to_sents(dec_logit, y_lens, vocab)
                   for dec_logit, y_lens in zip(dec_logits, ys_lens)]

        return x_sents, yi_sents, yt_sents, o_sents
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号