def _get_rnn_output(self, input_word, input_char, mask=None, length=None, hx=None):
    # derive length from mask when it is not given
    # we do not derive mask from length for special reasons,
    # so always provide mask when it is needed
    if length is None and mask is not None:
        length = mask.data.sum(dim=1).long()
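    # note: mask is assumed to be a [batch, length] tensor of 1s over real
    # tokens and 0s over padding, so its row sum recovers each true length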
    # [batch, length, word_dim]
    word = self.word_embedd(input_word)
    # [batch, length, char_length, char_dim]
    char = self.char_embedd(input_char)
    char_size = char.size()
    # first transform to [batch * length, char_length, char_dim]
    # then transpose to [batch * length, char_dim, char_length]
    char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2)
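    # (the transpose is needed because nn.Conv1d expects input of shape
    # [batch, in_channels, seq_len], with channels before the length axis)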
    # put into the cnn: [batch * length, char_filters, char_length]
    # then max-pool over the char dimension: [batch * length, char_filters]
    char, _ = self.conv1d(char).max(dim=2)
    # apply tanh and reshape to [batch, length, char_filters]
    char = torch.tanh(char).view(char_size[0], char_size[1], -1)
    # concatenate word and char: [batch, length, word_dim + char_filters]
    input = torch.cat([word, char], dim=2)
    # apply dropout
    input = self.dropout_in(input)
    # prepare packed sequence
    if length is not None:
        seq_input, hx, rev_order, mask = utils.prepare_rnn_seq(input, length, hx=hx, masks=mask, batch_first=True)
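        # prepare_rnn_seq / recover_rnn_seq are helper functions from this
        # codebase; presumably they sort the batch by length, pack it with
        # pack_padded_sequence, and later unpack it and restore the original
        # order, with rev_order carrying the permutation that undoes the sort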
        seq_output, hn = self.rnn(seq_input, hx=hx)
        output, hn = utils.recover_rnn_seq(seq_output, rev_order, hx=hn, batch_first=True)
    else:
        # output from rnn: [batch, length, hidden_size]
        output, hn = self.rnn(input, hx=hx)
    output = self.dropout_rnn(output)
    if self.dense is not None:
        # [batch, length, tag_space]
        output = F.elu(self.dense(output))
    return output, hn, mask, length
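
# A minimal, self-contained sketch of the character-CNN pooling pattern used
# above (illustrative only: the vocabulary size, embedding dimension, and
# kernel settings below are made-up assumptions, not values from the original
# model).
import torch
import torch.nn as nn

batch, seq_len, char_len = 2, 5, 7
char_dim, char_filters = 30, 50

char_embedd = nn.Embedding(100, char_dim)
conv1d = nn.Conv1d(char_dim, char_filters, kernel_size=3, padding=1)

input_char = torch.randint(0, 100, (batch, seq_len, char_len))
char = char_embedd(input_char)  # [2, 5, 7, 30]
# fold the batch and sequence axes together, move channels before length
char = char.view(batch * seq_len, char_len, char_dim).transpose(1, 2)  # [10, 30, 7]
# convolve over characters, then max-pool the character axis away
char, _ = conv1d(char).max(dim=2)  # [10, 50]
char = torch.tanh(char).view(batch, seq_len, -1)  # [2, 5, 50]
print(char.shape)  # torch.Size([2, 5, 50])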