def bidirectional_rnn(self) -> Tuple[Tuple[tf.Tensor, tf.Tensor],
                                     Tuple[tf.Tensor, tf.Tensor]]:
    """Run a bidirectional dynamic RNN over ``self.highway_layer``.

    The forward and backward cells come from ``self.rnn_cells()``. Each
    example's length in ``self.input_sequence.lengths`` is divided by
    ``self.segment_size``, rounded up, and cast to ``int32``. The result is
    passed to ``tf.nn.bidirectional_dynamic_rnn`` as ``sequence_length``.
    The RNN runs with ``dtype=tf.float32``.

    Returns:
        The result of ``tf.nn.bidirectional_dynamic_rnn``. Per the return
        annotation, this is a pair of tensor pairs; TensorFlow documents it
        as ``(outputs, output_states)``, each a ``(forward, backward)`` pair.
    """
    # BiRNN Network
    forward_cell, backward_cell = self.rnn_cells()  # type: RNNCellTuple

    # NOTE(review): this presumably assumes the highway layer has one time
    # step per `segment_size` input steps, with a trailing partial segment
    # counted as a full step (hence the ceil). Confirm against whatever
    # builds `self.highway_layer`.
    scaled_lengths = tf.divide(self.input_sequence.lengths,
                               self.segment_size)
    segment_counts = tf.cast(tf.ceil(scaled_lengths), tf.int32)

    return tf.nn.bidirectional_dynamic_rnn(
        forward_cell,
        backward_cell,
        self.highway_layer,
        sequence_length=segment_counts,
        dtype=tf.float32,
    )
评论列表
文章目录