def forward(self, unique_word_chars, unique_word_lengths, sequences_as_uniqs):
    # Use CUDA long tensors when a GPU is present, otherwise fall back to CPU.
    long_tensor = torch.cuda.LongTensor if torch.cuda.device_count() > 0 else torch.LongTensor

    # Embed the characters of every unique word: [N, L, char_emb_dim]
    embedded_chars = self._embeddings(unique_word_chars.type(long_tensor))

    # 1D convolution over the character positions: [N, S, L] (S filters)
    conv_out = self._conv(embedded_chars.transpose(1, 2))

    # Large negative values at padded character positions: [N, L]
    conv_mask = misc.mask_for_lengths(unique_word_lengths)
    conv_out = conv_out + conv_mask.unsqueeze(1)

    # Max-pool over the character axis to get one vector per unique word: [N, S]
    embedded_words = conv_out.max(2)[0]

    if not isinstance(sequences_as_uniqs, list):
        sequences_as_uniqs = [sequences_as_uniqs]

    # Look up the pooled word vectors for each sequence of unique-word indices.
    all_embedded = []
    for word_idxs in sequences_as_uniqs:
        all_embedded.append(functional.embedding(
            word_idxs.type(long_tensor), embedded_words))
    return all_embedded
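
The method above is the forward pass of a character-CNN word embedder: the characters of every unique word are embedded, convolved, masked at padded positions, and max-pooled into one vector per word, and those vectors are then gathered for each sequence of word indices. Below is a minimal, self-contained sketch of what the enclosing module and a call might look like; the class name CharCNNWordEmbedder, the layer sizes, and the local mask_for_lengths helper (standing in for misc.mask_for_lengths) are assumptions for illustration, not the original library's code.

# Hypothetical sketch of the enclosing module; names and sizes are assumptions.
import torch
from torch import nn
from torch.nn import functional


def mask_for_lengths(lengths, max_length=None, value=-1e6):
    # Assumed stand-in for misc.mask_for_lengths: 0 at valid character
    # positions, a large negative number at padded ones, shape [N, L].
    max_length = max_length if max_length is not None else int(lengths.max())
    positions = torch.arange(max_length, device=lengths.device)
    return (positions.unsqueeze(0) >= lengths.unsqueeze(1)).float() * value


class CharCNNWordEmbedder(nn.Module):
    def __init__(self, num_chars=100, char_dim=16, word_dim=32, kernel=5):
        super().__init__()
        self._embeddings = nn.Embedding(num_chars, char_dim, padding_idx=0)
        self._conv = nn.Conv1d(char_dim, word_dim, kernel, padding=kernel // 2)

    def forward(self, unique_word_chars, unique_word_lengths, sequences_as_uniqs):
        embedded_chars = self._embeddings(unique_word_chars.long())  # [N, L, char_dim]
        conv_out = self._conv(embedded_chars.transpose(1, 2))        # [N, word_dim, L]
        conv_out = conv_out + mask_for_lengths(unique_word_lengths).unsqueeze(1)
        embedded_words = conv_out.max(2)[0]                          # [N, word_dim]
        if not isinstance(sequences_as_uniqs, list):
            sequences_as_uniqs = [sequences_as_uniqs]
        return [functional.embedding(idxs.long(), embedded_words)
                for idxs in sequences_as_uniqs]


# Usage: 4 unique words of up to 6 characters, one sequence of 5 word positions.
module = CharCNNWordEmbedder()
chars = torch.randint(1, 100, (4, 6))
lengths = torch.tensor([6, 3, 4, 2])
sequence = torch.tensor([[0, 2, 1, 3, 2]])
out = module(chars, lengths, sequence)[0]  # shape [1, 5, 32]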