def forward(self, input):
    """Encode each word as a fixed-size vector built from its characters.

    Args:
        input: Tensor indexed as ``(bsz, word_len, char_len)`` (the three
            sizes are unpacked from ``input.size()``). Presumably integer
            character ids accepted by ``self.char_ebd`` -- confirm with caller.

    Returns:
        Tensor viewed as ``(bsz, word_len, -1)``. The last dimension holds the
        per-word features after conv + ReLU, max-pooling over all of dim 2, and
        dropout. Its size depends on how ``self.char_cnn`` is configured.
    """
    bsz, word_len, char_len = input.size()
    # Treat every word as an independent sequence of characters.
    encode = input.view(-1, char_len)
    # Add a singleton channel dim so the embedded characters can be fed to
    # the 2-D convolution as a one-channel "image".
    encode = self.char_ebd(encode).unsqueeze(1)
    encode = F.relu(self.char_cnn(encode))
    # Pool over the entire dim 2 (the axis derived from character positions),
    # leaving a single maximum per remaining position.
    encode = F.max_pool2d(encode, kernel_size=(encode.size(2), 1))
    # Reshape straight to the output layout rather than squeeze()-ing first:
    # squeeze() would also collapse incidental size-1 dims (e.g. a single
    # word), and this view defines the final shape explicitly either way.
    encode = encode.view(bsz, word_len, -1)
    # F.dropout defaults to training=True; tie it to the module's train/eval
    # state so inference (model.eval()) is deterministic and unscaled.
    return F.dropout(encode, p=self.dropout, training=self.training)
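# 评论列表 ("Comment list") -- blog-page navigation residue, not code
# 文章目录 ("Article contents") -- blog-page navigation residue, not code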