def _get_encoder_output(self, input_word, input_char, input_pos, mask_e=None, length_e=None, hx=None):
    # [batch, length, word_dim]
    word = self.word_embedd(input_word)
    # [batch, length, pos_dim]
    pos = self.pos_embedd(input_pos)
    # [batch, length, char_length, char_dim]
    char = self.char_embedd(input_char)
    char_size = char.size()
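    # char_size is kept so the [batch, length] layout can be restored after pooling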
    # first reshape to [batch * length, char_length, char_dim],
    # then transpose to [batch * length, char_dim, char_length]
    char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2)
    # feed through the 1d convolution: [batch * length, char_filters, char_length],
    # then max-pool over the character dimension: [batch * length, char_filters]
    char, _ = self.conv1d(char).max(dim=2)
    # apply tanh and reshape to [batch, length, char_filters]
    char = torch.tanh(char).view(char_size[0], char_size[1], -1)
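    # (max-pooling collapses the variable-length character axis, so every word
    # ends up with a fixed-size character-level feature vector)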
    # apply dropout to each input embedding
    word = self.dropout_in(word)
    pos = self.dropout_in(pos)
    char = self.dropout_in(char)
    # concatenate word, char and pos embeddings: [batch, length, word_dim + char_filters + pos_dim]
    src_encoding = torch.cat([word, char, pos], dim=2)
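    # (this concatenated vector is the per-token input fed to the encoder RNN below)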
    # output from the encoder RNN: [batch, length, hidden_size]
    output, hn = self.encoder(src_encoding, mask_e, hx=hx)
    # apply dropout to the encoder output; the transpose means that a channel-wise
    # dropout (if that is what dropout_out is) zeroes entire hidden dimensions
    # across the whole sequence:
    # [batch, length, hidden_size] --> [batch, hidden_size, length] --> [batch, length, hidden_size]
    output = self.dropout_out(output.transpose(1, 2)).transpose(1, 2)
    return src_encoding, output, hn, mask_e, length_e
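
A minimal, self-contained sketch of the character-CNN step above, for tracing the tensor shapes. All names and sizes here (the char vocabulary, char_dim, char_filters, the kernel width) are illustrative assumptions, not the author's actual configuration:

import torch
import torch.nn as nn

batch, length, char_length = 2, 5, 7      # assumed toy sizes
char_dim, char_filters = 30, 50           # assumed embedding / filter sizes

char_embedd = nn.Embedding(100, char_dim)                        # assumed char vocab of 100
conv1d = nn.Conv1d(char_dim, char_filters, kernel_size=3, padding=1)

input_char = torch.randint(0, 100, (batch, length, char_length))
char = char_embedd(input_char)                                   # [batch, length, char_length, char_dim]
char = char.view(batch * length, char_length, char_dim).transpose(1, 2)
char, _ = conv1d(char).max(dim=2)                                # max-pool over character positions
char = torch.tanh(char).view(batch, length, -1)
print(char.shape)                                                # torch.Size([2, 5, 50])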