def global_handle(self, emb_layer, flag):
    # Run the shared embedding through a forward LSTM, a backward LSTM,
    # and a convolution branch (with dropout).
    fw_lstm_out = self.forward_lstm(emb_layer)
    bw_lstm_out = self.backward_lstm(emb_layer)
    conv_out = self.conv_dropout(self.conv(emb_layer))

    # Project each branch to attention_dim at every time step, then apply attention.
    fw_lstm_out = TimeDistributed(Dense(self.params['attention_dim']), name='fw_tb_' + flag)(fw_lstm_out)
    fw_lstm_att = Attention()(fw_lstm_out)
    # fw_lstm_att = Reshape((self.params['lstm_output_dim'], 1))(fw_lstm_att)

    conv_out = TimeDistributed(Dense(self.params['attention_dim']), name='conv_tb_' + flag)(conv_out)
    conv_att = Attention()(conv_out)
    # conv_att = Reshape((self.params['filters'], 1))(conv_att)

    bw_lstm_out = TimeDistributed(Dense(self.params['attention_dim']), name='bw_tb_' + flag)(bw_lstm_out)
    bw_lstm_att = Attention()(bw_lstm_out)
    # bw_lstm_att = Reshape((self.params['lstm_output_dim'], 1))(bw_lstm_att)

    # Merge the three attended branches along the feature axis.
    return concatenate([fw_lstm_att, conv_att, bw_lstm_att], axis=2)
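
# --- Hedged sketch (not part of the original post) ---
# The method above assumes layer imports such as
# `from keras.layers import Dense, TimeDistributed, Reshape, concatenate`.
# It also calls `Attention()(x)` on a single 3-D tensor, which does not match the
# two-input signature of the built-in keras.layers.Attention, so a custom layer is
# assumed. A minimal sequence-preserving self-attention that would keep the
# axis=2 concatenation valid could look like this; the scoring scheme and weight
# names here are assumptions, not the author's code.
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class Attention(Layer):
    """Scores each time step and rescales the sequence by its attention weight.
    Input (batch, steps, dim) -> output (batch, steps, dim)."""

    def build(self, input_shape):
        # One learned projection vector used to score every time step.
        self.W = self.add_weight(name='att_W',
                                 shape=(int(input_shape[-1]), 1),
                                 initializer='glorot_uniform',
                                 trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        # (batch, steps, dim) x (dim, 1) -> (batch, steps), softmax over time axis.
        scores = K.softmax(K.squeeze(K.dot(inputs, self.W), axis=-1), axis=1)
        # Reweight each time step by its attention score.
        return inputs * K.expand_dims(scores, axis=-1)

    def compute_output_shape(self, input_shape):
        return input_shape

# Note: if the original Attention instead pooled each sequence to a 2-D vector
# (batch, dim), the commented-out Reshape lines would be needed before the
# axis=2 concatenation in global_handle.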