data.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:dong_iccv_2017 作者: woozzu 项目源码 文件源码
def _get_word_vectors(self, desc, word_embedding):
        output = []
        len_desc = []
        for i in range(desc.shape[1]):
            words = self._nums2chars(desc[:, i])
            words = split_sentence_into_words(words)
            word_vecs = torch.Tensor([word_embedding[w] for w in words])
            # zero padding
            if len(words) < self.max_word_length:
                word_vecs = torch.cat((
                    word_vecs,
                    torch.zeros(self.max_word_length - len(words), word_vecs.size(1))
                ))
            output.append(word_vecs)
            len_desc.append(len(words))
        return torch.stack(output), len_desc
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号