def forward(self, x):
    """Run embed -> CNN -> BiLSTM -> max-pool -> two-layer head; return logits.

    Pipeline as written:
      1. ``self.embed(x)`` followed by ``self.dropout``.
      2. Insert a size-1 dim at position 1, apply every conv in
         ``self.convs1`` with ReLU, and squeeze each output's dim 3.
      3. Concatenate the conv outputs along dim 0, then swap dims 1 and 2.
      4. Run ``self.bilstm`` starting from ``self.hidden``.
         ``self.hidden`` is overwritten with the returned state.
      5. Max-pool the LSTM output over its dim 0, leaving the other two dims.
      6. ``tanh -> self.hidden2label1 -> tanh -> self.hidden2label2``.
      7. Apply ``self.dropout`` to the result and return it.

    Args:
        x: token-id tensor passed to ``self.embed``. Its layout (seq-first
            vs batch-first) is not visible here; see the NOTE below.

    Returns:
        The output of ``self.hidden2label2`` after dropout.

    NOTE(review): the conv outputs are concatenated along dim 0. The original
    inline comment labelled each conv output ``(N, Co, W)``, which suggests
    dim 0 is the batch. However, ``self.bilstm`` treats dim 0 of its input as
    the sequence axis unless it was built with ``batch_first=True``. Confirm
    against ``__init__`` and the data loader which axis this concat really
    extends.
    NOTE(review): ``self.hidden`` persists across calls. If it is not reset
    or detached per batch elsewhere, backward passes may reach into earlier
    batches' graphs.
    """
    # Embed token ids, then regularize the embeddings with dropout.
    embed = self.dropout(self.embed(x))

    # CNN: add a singleton dim at position 1, used as Conv2d's input-channel
    # dim. Each conv output goes through ReLU and loses its dim 3, which is
    # presumably size 1 when the kernel spans the full embedding width;
    # TODO confirm in __init__.
    cnn_in = embed.unsqueeze(1)
    conv_outs = [F.relu(conv(cnn_in)).squeeze(3) for conv in self.convs1]
    # Concat along dim 0 (see NOTE in docstring), then swap dims 1 and 2 so
    # the conv channels become the last (feature) dim fed to the LSTM.
    cnn_x = torch.cat(conv_outs, 0)
    cnn_x = torch.transpose(cnn_x, 1, 2)

    # BiLSTM: self.hidden is both the initial state and the stored final state.
    bilstm_out, self.hidden = self.bilstm(cnn_x, self.hidden)

    # (d0, d1, d2) -> (d1, d2, d0), same as transpose(0,1) then transpose(1,2).
    # Then max-pool over the whole of d0 and drop the size-1 result dim.
    bilstm_out = bilstm_out.permute(1, 2, 0)
    bilstm_out = F.max_pool1d(bilstm_out, bilstm_out.size(2)).squeeze(2)

    # Linear head with tanh before each projection. torch.tanh replaces the
    # deprecated F.tanh; the computation is identical.
    cnn_bilstm_out = self.hidden2label1(torch.tanh(bilstm_out))
    cnn_bilstm_out = self.hidden2label2(torch.tanh(cnn_bilstm_out))

    # NOTE(review): dropout on the final logits is unusual; confirm it is
    # intended rather than meant to precede the head.
    logit = self.dropout(cnn_bilstm_out)
    return logit
# model_CBiLSTM.py — file source
# Language: python
# Reads: 25
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents