embedding.py source code

python

Project: jack · Author: uclmr

The snippet below is the forward pass of a character-level CNN word embedder: it embeds the characters of each unique word, convolves over them, masks positions beyond each word's length, max-pools into one vector per unique word, and then looks up those vectors for each input sequence.
import torch
from torch.nn import functional

# `misc` is a helper module from the jack codebase; its `mask_for_lengths`
# returns a [N, L] tensor that is 0.0 within each word's length and a large
# negative value at padded positions.

def forward(self, unique_word_chars, unique_word_lengths, sequences_as_uniqs):
    # Fall back to CPU tensors when no GPU is present.
    long_tensor = torch.cuda.LongTensor if torch.cuda.device_count() > 0 else torch.LongTensor
    # [N, L, E]: character embeddings for the N unique words (max length L)
    embedded_chars = self._embeddings(unique_word_chars.type(long_tensor))
    # [N, C, L]: 1D convolution over characters (input transposed to [N, E, L])
    conv_out = self._conv(embedded_chars.transpose(1, 2))
    # [N, L]: 0 within each word, large negative beyond its length
    conv_mask = misc.mask_for_lengths(unique_word_lengths)
    conv_out = conv_out + conv_mask.unsqueeze(1)
    # Max-pool over the character dimension: one vector per unique word, [N, C]
    embedded_words = conv_out.max(2)[0]

    if not isinstance(sequences_as_uniqs, list):
        sequences_as_uniqs = [sequences_as_uniqs]

    # Map each sequence of unique-word indices to the pooled word embeddings.
    all_embedded = []
    for word_idxs in sequences_as_uniqs:
        all_embedded.append(functional.embedding(
            word_idxs.type(long_tensor), embedded_words))
    return all_embedded
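To make the snippet concrete, here is a minimal, self-contained usage sketch. The class name `CharCNNEmbedder`, the layer sizes, and the stand-in `mask_for_lengths` are illustrative assumptions, not the jack project's actual definitions; they only mirror the behavior the forward pass above relies on.

python
import torch
from torch import nn
from torch.nn import functional


def mask_for_lengths(lengths, max_length=None, value=-1e6):
    """Stand-in for jack's misc.mask_for_lengths: 0.0 inside each word,
    a large negative value at padded positions, shape [N, L]."""
    max_length = max_length if max_length is not None else int(lengths.max().item())
    positions = torch.arange(max_length, device=lengths.device).unsqueeze(0)  # [1, L]
    return (positions >= lengths.unsqueeze(1)).float() * value


class CharCNNEmbedder(nn.Module):
    """Illustrative wrapper around the forward pass above; sizes are arbitrary."""

    def __init__(self, num_chars=100, char_dim=16, out_dim=32):
        super().__init__()
        self._embeddings = nn.Embedding(num_chars, char_dim)
        # padding=1 with kernel_size=3 preserves the length dimension,
        # so the [N, L] mask broadcasts cleanly over the conv output.
        self._conv = nn.Conv1d(char_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, unique_word_chars, unique_word_lengths, sequences_as_uniqs):
        embedded_chars = self._embeddings(unique_word_chars.long())   # [N, L, E]
        conv_out = self._conv(embedded_chars.transpose(1, 2))         # [N, C, L]
        conv_mask = mask_for_lengths(unique_word_lengths)             # [N, L]
        conv_out = conv_out + conv_mask.unsqueeze(1)
        embedded_words = conv_out.max(2)[0]                           # [N, C]
        if not isinstance(sequences_as_uniqs, list):
            sequences_as_uniqs = [sequences_as_uniqs]
        return [functional.embedding(idxs.long(), embedded_words)
                for idxs in sequences_as_uniqs]


# Three unique words of lengths 3, 5, and 2, padded to length 5.
chars = torch.tensor([[4, 7, 2, 0, 0],
                      [3, 9, 9, 5, 1],
                      [8, 6, 0, 0, 0]])
lengths = torch.tensor([3, 5, 2])
# A 4-token sequence referring to those unique words by index.
seq = torch.tensor([[0, 2, 1, 0]])

embedder = CharCNNEmbedder()
out = embedder(chars, lengths, seq)
print(out[0].shape)  # torch.Size([1, 4, 32])

Computing embeddings once per unique word and then indexing into the result avoids re-running the character CNN for every repeated token, which is the point of passing `sequences_as_uniqs` as indices rather than raw characters.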