def forward(self, x):
    """Run the CNN and BiGRU branches over ``x`` and return class logits.

    Both branches share one embedding of ``x`` (passed through
    ``self.dropout``). Each branch is max-pooled over its sequence axis and
    squashed with tanh. The two feature vectors are concatenated along the
    feature axis and fed through ``hidden2label1`` then ``hidden2label2``,
    with a tanh applied before each.

    NOTE(review): the shapes mentioned below are inferred from the
    transposes/views and the original inline comments; they are not
    enforced here. Assumes ``x`` is (seq_len, batch) token ids and that
    ``self.bigru`` is sequence-first. Confirm against the model constructor.

    Args:
        x: input passed to ``self.embed``; ``len(x)`` is used as the
            leading (sequence) dimension of the GRU input.

    Returns:
        The output of ``self.hidden2label2``.

    Side effects:
        Overwrites ``self.hidden`` with the final hidden state returned by
        ``self.bigru``. NOTE(review): that state is not detached, so the
        caller presumably re-initialises ``self.hidden`` every batch.
        Verify this; otherwise backward will reach into earlier batches'
        graphs.
    """
    embed = self.dropout(self.embed(x))

    # --- CNN branch ---
    # Swap the first two axes and add a singleton channel axis for the convs
    # (assumed (seq, N, D) -> (N, 1, seq, D)).
    cnn_x = torch.transpose(embed, 0, 1).unsqueeze(1)
    # Each conv yields (N, Co, W, 1); drop the trailing width-1 axis.
    conv_maps = [conv(cnn_x).squeeze(3) for conv in self.convs1]
    # Global max-pool over W -> (N, Co) per kernel, then tanh.
    pooled = [
        torch.tanh(F.max_pool1d(fmap, fmap.size(2)).squeeze(2))
        for fmap in conv_maps
    ]
    cnn_x = self.dropout(torch.cat(pooled, 1))

    # --- BiGRU branch ---
    bigru_x = embed.view(len(x), embed.size(1), -1)
    bigru_x, self.hidden = self.bigru(bigru_x, self.hidden)
    # Move the sequence axis last so max_pool1d reduces over it.
    bigru_x = torch.transpose(bigru_x, 0, 1)
    bigru_x = torch.transpose(bigru_x, 1, 2)
    bigru_x = F.max_pool1d(bigru_x, bigru_x.size(2)).squeeze(2)
    bigru_x = torch.tanh(bigru_x)

    # --- fuse ---
    # Concatenate on the feature axis. This matches the former
    # transpose / cat(dim=0) / transpose sequence value-for-value.
    cnn_bigru_out = torch.cat((cnn_x, bigru_x), 1)

    # --- classifier ---
    # NOTE(review): both branch outputs already went through tanh above, so
    # this is a second tanh. It is kept to preserve trained-model behavior.
    cnn_bigru_out = self.hidden2label1(torch.tanh(cnn_bigru_out))
    logit = self.hidden2label2(torch.tanh(cnn_bigru_out))
    return logit
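# NOTE: the lines below are web-page residue captured along with the code
# above; they are commented out so the file parses as Python.
# model_CNN_BiGRU.py 文件源码 (source code of model_CNN_BiGRU.py)
# python
# 阅读 24 (views: 24)
# 收藏 0 (favorites: 0)
# 点赞 0 (likes: 0)
# 评论 0 (comments: 0)
# 评论列表 (comment list)
# 文章目录 (table of contents)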