def forward(self, source_sentences, target_sentences):
    """Classify a sentence pair using one shared (Siamese) sentence encoder.

    Follows the pair-classification setup of "Supervised Learning of
    Universal Sentence Representations from Natural Language Inference Data"
    (https://arxiv.org/abs/1705.02364). Both sentences are encoded with the
    same ``self.encoder``, giving vectors u and v. These are combined into
    the feature vector [u; v; |u - v|; u * v], which is fed to
    ``self.classifier``.

    :param source_sentences: A tuple of Variables holding the padded sentence
        tensor batch [seq. length, batch size, embed. size] and the sentence
        lengths.
    :param target_sentences: Same format as ``source_sentences``.
    :return: The classifier output for the pair features, expected shape
        (batch size, # classes). This is the classification output, not a
        sentence embedding; the embeddings are the intermediate u and v.
    """
    # The same encoder (shared weights) embeds both sides of the pair.
    u = self.encoder(source_sentences)
    v = self.encoder(target_sentences)
    # Concatenate along dim 1, presumably the feature axis of a
    # (batch size, hidden) encoder output (TODO confirm). The result is
    # four encoder-widths wide.
    features = torch.cat((u, v, torch.abs(u - v), u * v), 1)
    return self.classifier(features)
评论列表
文章目录