transform.py source code

python

Project: torch_light  Author: ne7ermore
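import torch
from torch.autograd import Variable

# WORD, BOS, EOS, UNK, PAD and normalizeString are defined elsewhere in the project.


def sent2tenosr(self, sentence):
    # Reserve two slots for the BOS and EOS markers.
    max_len = self.args.max_word_len - 2
    sentence = normalizeString(sentence)
    words = sentence.strip().split()[:max_len]

    words = [WORD[BOS]] + words + [WORD[EOS]]
    # Words missing from the source dictionary map to UNK.
    idx = [self.src_dict[w] if w in self.src_dict else UNK for w in words]

    idx_data = torch.LongTensor(idx)
    # Positions are 1-based; 0 is reserved for PAD tokens.
    idx_position = torch.LongTensor([pos_i+1 if w_i != PAD else 0 for pos_i, w_i in enumerate(idx)])
    # unsqueeze(0) adds a batch dimension of size 1.
    # volatile=True marked inference-only inputs in PyTorch < 0.4; later
    # versions ignore the flag in favor of torch.no_grad().
    idx_data_tensor = Variable(idx_data.unsqueeze(0), volatile=True)
    idx_position_tensor = Variable(idx_position.unsqueeze(0), volatile=True)

    if self.cuda:
        idx_data_tensor = idx_data_tensor.cuda()
        idx_position_tensor = idx_position_tensor.cuda()

    return idx_data_tensor, idx_position_tensor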
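
The index and position logic of the method above can be tried without PyTorch or the project's own modules. The following is a minimal sketch, not the project's code: the special-token ids, the `<s>` and `</s>` marker strings, the `encode` helper and the toy `vocab` are all hypothetical stand-ins for the project's `PAD`, `UNK`, `BOS`, `EOS`, `WORD` and `src_dict`.

```python
# Hypothetical special-token ids; the real project defines its own constants.
PAD, UNK, BOS, EOS = 0, 1, 2, 3
BOS_WORD, EOS_WORD = "<s>", "</s>"  # assumed marker strings


def encode(sentence, vocab, max_word_len):
    """Return (token ids, 1-based positions) for one sentence."""
    # Reserve two slots for the BOS and EOS markers, then truncate.
    words = sentence.strip().split()[: max_word_len - 2]
    words = [BOS_WORD] + words + [EOS_WORD]
    # Out-of-vocabulary words fall back to UNK.
    ids = [vocab.get(w, UNK) for w in words]
    # Positions start at 1; 0 is reserved for padding tokens.
    positions = [i + 1 if tok != PAD else 0 for i, tok in enumerate(ids)]
    return ids, positions


# Toy vocabulary, for illustration only.
vocab = {BOS_WORD: BOS, EOS_WORD: EOS, "hello": 4, "world": 5}
ids, positions = encode("hello there world", vocab, max_word_len=10)
print(ids)        # [2, 4, 1, 5, 3]
print(positions)  # [1, 2, 3, 4, 5]
```

On current PyTorch versions, the two lists could then be wrapped as `torch.tensor([ids])` and `torch.tensor([positions])`, with the forward pass run inside `with torch.no_grad():` instead of relying on `volatile=True`.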