def forward(self, x):
    """Compute class logits with a BiLSTM encoder followed by a highway layer.

    The input is embedded, passed through dropout and the bidirectional
    LSTM, max-pooled over the LSTM output's first dimension, and then
    combined through a single highway layer before the final projection.

    Side effect: ``self.hidden`` is overwritten with the state returned by
    ``self.bilstm`` and is fed back in on the next call.
    NOTE(review): presumably the caller resets ``self.hidden`` between
    batches -- confirm, otherwise state leaks from one batch into the next.

    Args:
        x: Input accepted by ``self.embed``. Assumed to be token indices
            laid out as (seq_len, batch), i.e. an LSTM built with
            ``batch_first=False`` -- TODO confirm against the constructor.

    Returns:
        The output of ``self.logit_layer`` applied to the highway output.
    """
    x = self.embed(x)
    x = self.dropout(x)
    bilstm_out, self.hidden = self.bilstm(x, self.hidden)

    # Swap dims 0/1 and then 1/2 so the original first dimension ends up
    # last, which is the dimension max_pool1d pools over.
    bilstm_out = torch.transpose(bilstm_out, 0, 1)
    bilstm_out = torch.transpose(bilstm_out, 1, 2)
    # A kernel as wide as the whole last dimension makes this a global max
    # pool; squeeze(2) drops the resulting singleton dimension.
    bilstm_out = F.max_pool1d(bilstm_out, bilstm_out.size(2)).squeeze(2)

    # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid
    # aliases; the computed values are the same.
    transformed = self.hidden2label1(torch.tanh(bilstm_out))
    gate = torch.sigmoid(self.gate_layer(bilstm_out))

    # Highway Networks formula: y = H(x) * T(x) + x * (1 - T(x)).
    # The carry branch must use the untransformed pooled features
    # (bilstm_out), not the transformed ones, to match that formula.
    highway_output = transformed * gate + (1 - gate) * bilstm_out

    return self.logit_layer(highway_output)
model_HighWay_BiLSTM_1.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录