def _init_bidirectional_encoder(self):
    """Build the bidirectional encoder and store its outputs and final state.

    Runs ``tf.nn.bidirectional_dynamic_rnn`` inside the
    ``"BidirectionalEncoder"`` variable scope and sets:

    * ``self.encoder_outputs``: forward and backward outputs concatenated
      along axis 2 (presumably the feature axis of rank-3 outputs in both
      time-major and batch-major layouts -- confirm against the caller).
    * ``self.encoder_state``: forward and backward final states
      concatenated along axis 1. ``LSTMStateTuple`` states are combined
      component-wise (``c`` with ``c``, ``h`` with ``h``), plain
      ``tf.Tensor`` states are concatenated directly, and plain tuples of
      states (e.g. one state per layer of a stacked cell) are combined
      element by element.

    Reads ``self.encoder_cell``, ``self.encoder_inputs_embedded``,
    ``self.encoder_inputs_length`` and ``self.time_major``.

    Raises:
        TypeError: if the final state is not an ``LSTMStateTuple``, a
            ``tf.Tensor``, or a tuple of those, so that an unsupported cell
            fails here instead of leaving ``self.encoder_state`` unset.
    """

    def _concat_states(fw_state, bw_state):
        """Concatenate one forward/backward final-state pair along axis 1."""
        # LSTMStateTuple is itself a tuple subclass (a namedtuple in
        # TensorFlow), so it must be tested before the generic tuple branch.
        if isinstance(fw_state, LSTMStateTuple):
            state_c = tf.concat(
                (fw_state.c, bw_state.c), 1, name='bidirectional_concat_c')
            state_h = tf.concat(
                (fw_state.h, bw_state.h), 1, name='bidirectional_concat_h')
            return LSTMStateTuple(c=state_c, h=state_h)
        if isinstance(fw_state, tf.Tensor):
            return tf.concat((fw_state, bw_state), 1, name='bidirectional_concat')
        if isinstance(fw_state, tuple):
            # Nested state structure (e.g. one state per layer): combine each
            # forward/backward pair independently, preserving the structure.
            return tuple(
                _concat_states(fw, bw) for fw, bw in zip(fw_state, bw_state))
        raise TypeError(
            'Unsupported encoder state type for bidirectional concat: %r'
            % (type(fw_state),))

    with tf.variable_scope("BidirectionalEncoder"):
        # NOTE(review): the same cell object is passed as both cell_fw and
        # cell_bw. Depending on the TensorFlow version this either shares
        # weights between the two directions or raises a scope-reuse error;
        # confirm this is intended, or supply a separate backward cell.
        # dtype is required because no initial states are passed in.
        ((encoder_fw_outputs, encoder_bw_outputs),
         (encoder_fw_state, encoder_bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=self.encoder_cell,
            cell_bw=self.encoder_cell,
            inputs=self.encoder_inputs_embedded,
            sequence_length=self.encoder_inputs_length,
            time_major=self.time_major,
            dtype=tf.float32)

        self.encoder_outputs = tf.concat(
            (encoder_fw_outputs, encoder_bw_outputs), 2)
        self.encoder_state = _concat_states(encoder_fw_state, encoder_bw_state)
dynamic_seq2seq_model.py 文件源码
python
阅读 49
收藏 0
点赞 0
评论 0
评论列表
文章目录