model.py source code

python

Project: pytorch-skipthoughts  Author: kaniblu
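# Method of the model class; requires `import torch` and the project's
# pad_batch helper, both defined elsewhere in the source.
def forward(self, x, x_lens, ys, ys_lens, xys_idx):
    # Embed the source sentences and encode them into sentence vectors h.
    x = self.embeddings(x)
    h = self._encode_embed(x, x_lens)

    if self.batch_first:
        # Put the decoder axis first so the loop below runs once per decoder.
        ys = ys.transpose(1, 0)
        ys_lens = ys_lens.transpose(1, 0)
        xys_idx = xys_idx.transpose(1, 0)

    logits_list = []

    for dec_idx, (y, y_lens, xy_idx) in enumerate(
            zip(ys, ys_lens, xys_idx)):
        # For each target sentence, gather the encoding of its source sentence.
        h_dec = torch.index_select(h, 0, xy_idx)
        logits = self._decode(dec_idx, h_dec, y, y_lens)

        # If _decode returned fewer rows than h_dec, pad logits back to the
        # full batch size with pad_batch.
        nil_batches = len(h_dec) - len(logits)
        if nil_batches:
            logits = pad_batch(logits, nil_batches, True)

        logits_list.append(logits.unsqueeze(0))

    # Stack the per-decoder logits along a new leading decoder axis.
    logits = torch.cat(logits_list)

    if self.batch_first:
        # Restore batch-first order: (batch, n_decoders, ...).
        logits = logits.transpose(1, 0)

    return logits, h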
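To make the tensor bookkeeping in `forward` concrete, here is a minimal, self-contained sketch under stated assumptions. It is not the pytorch-skipthoughts implementation: the class name `ToySkipThoughts`, the single-layer GRU encoder and decoders, the linear output projection `proj`, and feeding each target directly as decoder input are all hypothetical choices for illustration. It keeps the same pattern: encode once, move the decoder axis to the front, route encoder states to each decoder with `torch.index_select`, stack per-decoder logits with `torch.cat`, and transpose back to batch-first. Sequence lengths and the `pad_batch` step are left out because their exact behaviour in the project is not shown here.

```python
import torch
import torch.nn as nn


class ToySkipThoughts(nn.Module):
    """Hypothetical toy model mirroring the multi-decoder forward pass (batch-first)."""

    def __init__(self, vocab_size=20, emb_dim=8, hid_dim=16, n_decoders=2):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.GRU(emb_dim, hid_dim, batch_first=True)
        # One independent decoder per target sentence (assumption).
        self.decoders = nn.ModuleList(
            [nn.GRU(emb_dim, hid_dim, batch_first=True) for _ in range(n_decoders)])
        self.proj = nn.Linear(hid_dim, vocab_size)

    def forward(self, x, ys, xys_idx):
        # x: (batch, seq_x)  ys: (batch, n_dec, seq_y)  xys_idx: (batch, n_dec)
        _, h = self.encoder(self.embeddings(x))  # h: (1, batch, hid)
        h = h.squeeze(0)                         # (batch, hid)

        # Move the decoder axis to the front so we can loop over decoders.
        ys = ys.transpose(1, 0)                          # (n_dec, batch, seq_y)
        xys_idx = xys_idx.transpose(1, 0).contiguous()   # (n_dec, batch)

        logits_list = []
        for dec_idx, (y, xy_idx) in enumerate(zip(ys, xys_idx)):
            # Route each target row to the encoding of its source sentence.
            h_dec = torch.index_select(h, 0, xy_idx)  # (batch, hid)
            # Condition the decoder through its initial hidden state; targets
            # are fed directly as inputs for brevity (no shift, no lengths).
            out, _ = self.decoders[dec_idx](self.embeddings(y), h_dec.unsqueeze(0))
            logits_list.append(self.proj(out).unsqueeze(0))  # (1, batch, seq_y, vocab)

        logits = torch.cat(logits_list)   # (n_dec, batch, seq_y, vocab)
        return logits.transpose(1, 0), h  # (batch, n_dec, seq_y, vocab)


torch.manual_seed(0)
model = ToySkipThoughts()
x = torch.randint(0, 20, (3, 5))
ys = torch.randint(0, 20, (3, 2, 4))
xys_idx = torch.tensor([[0, 0], [1, 1], [2, 2]])
logits, h = model(x, ys, xys_idx)
print(logits.shape)  # torch.Size([3, 2, 4, 20])
```

With 3 source sentences, 2 decoders (e.g. one for the previous and one for the next sentence) and targets of length 4, the logits come out as `(batch, n_decoders, seq_y, vocab_size)`, the same batch-first layout `forward` restores with its final `transpose(1, 0)`.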