classifiers.py source code

Language: Python

Project: keita    Author: iwasaki-kenta
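```python
import torch


# Excerpted method (presumably of a torch.nn.Module subclass); self.encoder and
# self.classifier are defined elsewhere in the class and are not shown here.
def forward(self, source_sentences, target_sentences):
    """
    Supervised Learning of Universal Sentence Representations from Natural Language Inference Data
    https://arxiv.org/abs/1705.02364

    A Siamese text-classification network built with the goal of learning sentence embeddings:
    one shared encoder embeds both sentences, and a classifier predicts a label from the
    combined pair features.

    :param source_sentences: A tuple of Variables: a padded sentence tensor batch of shape
        [seq. length, batch size, embed. size] and the sentence lengths.
    :param target_sentences: A tuple of Variables: a padded sentence tensor batch of shape
        [seq. length, batch size, embed. size] and the sentence lengths.
    :return: Classifier output of shape (batch size, # classes).
    """

    # The same (weight-shared) encoder embeds both sentences: the Siamese part.
    u = self.encoder(source_sentences)
    v = self.encoder(target_sentences)

    # Pair features [u; v; |u - v|; u * v], concatenated along the feature
    # dimension (dim 1), so the classifier sees 4 x the encoder output size.
    features = torch.cat((u, v, torch.abs(u - v), u * v), 1)
    return self.classifier(features)
```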
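To make the `features` line in `forward` concrete, here is a minimal, dependency-free sketch of the same pair representation for one sentence pair. It assumes `u` and `v` are plain lists of floats standing in for one row each of the encoder outputs; `match_features` and `match_features_batch` are hypothetical helper names for illustration, not part of keita or PyTorch. With an encoder output size of d, each pair yields 4d features, which is why the classifier's first layer must accept four times the encoder's output size.

```python
from typing import List, Sequence


def match_features(u: Sequence[float], v: Sequence[float]) -> List[float]:
    """Hypothetical helper: build the pair vector [u; v; |u - v|; u * v].

    Mirrors torch.cat((u, v, torch.abs(u - v), u * v), 1) for a single row.
    """
    if len(u) != len(v):
        raise ValueError("u and v must have the same embedding size")
    abs_diff = [abs(a - b) for a, b in zip(u, v)]  # element-wise |u - v|
    product = [a * b for a, b in zip(u, v)]        # element-wise u * v
    return list(u) + list(v) + abs_diff + product


def match_features_batch(us: Sequence[Sequence[float]],
                         vs: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical helper: apply match_features row by row, like dim 1 of a (batch, d) tensor."""
    return [match_features(u, v) for u, v in zip(us, vs)]


u = [1.0, -2.0, 0.5]
v = [0.5, 1.0, 0.5]
features = match_features(u, v)
print(features)
# → [1.0, -2.0, 0.5, 0.5, 1.0, 0.5, 0.5, 3.0, 0.0, 0.5, -2.0, 0.25]
print(len(features))  # → 12, i.e. 4 * d for d = 3
```

The |u - v| and u * v terms are symmetric in the two sentences; only the plain concatenation [u; v] preserves which sentence came first.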