def simple_rnn(rnn_input, initial_state=None):
    """Run a static (unrolled) RNN over a list of per-step inputs.

    The recurrent cell is obtained from ``get_lstm_cell()``, and the graph is
    built with ``tf.nn.rnn``. That call unrolls the cell once per element of
    ``rnn_input``.

    Args:
      rnn_input: List of tensors of sizes [-1, sentembed_size], one per time
        step.
      initial_state: Optional initial state for the cell. When None,
        ``tf.nn.rnn`` creates a zero state of the selected ``dtype``.

    Returns:
      A tuple ``(encoder_outputs, encoder_state)``: the list of per-step
      outputs and the final cell state, exactly as returned by ``tf.nn.rnn``.
    """
    # Set up the recurrent cell (presumably an LSTM cell, judging by the
    # helper's name -- confirm against get_lstm_cell).
    cell_enc = get_lstm_cell()
    # Half precision when FLAGS.use_fp16 is set, otherwise float32. The dtype
    # is only consulted by tf.nn.rnn when it must build the initial state.
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    # NOTE: tf.nn.rnn is the pre-1.0 TensorFlow name of this op; in TF 1.x the
    # equivalent is tf.contrib.rnn.static_rnn / tf.nn.static_rnn.
    rnn_outputs, rnn_state = tf.nn.rnn(
        cell_enc, rnn_input, dtype=dtype, initial_state=initial_state
    )
    return rnn_outputs, rnn_state
# 评论列表 (scraped web-page residue: "comment list")
# 文章目录 (scraped web-page residue: "table of contents")