def recurrent_layer(tensor, cell=None, hidden_dims=128, sequence_length=None, decoder_fn=None,
                    activation=tf.nn.tanh, initializer=tf.orthogonal_initializer(), initial_state=None,
                    keep_prob=1.0,
                    return_final_state=False, return_next_cell_input=True, **opts):
    """Run an RNN cell over ``tensor`` (or drive a seq2seq decoder) and return the result.

    Args:
        tensor: Inputs passed to ``tf.nn.dynamic_rnn``. Ignored when ``decoder_fn``
            is given, because the decoder is called with ``inputs=None``.
        cell: RNN cell to run. Defaults to
            ``tf.contrib.rnn.BasicRNNCell(hidden_dims, activation=activation)``.
        hidden_dims: Number of units of the default cell; unused when ``cell`` is given.
        sequence_length: Forwarded to ``tf.nn.dynamic_rnn`` or to
            ``seq2seq.dynamic_rnn_decoder``.
        decoder_fn: If not None, ``seq2seq.dynamic_rnn_decoder(cell, decoder_fn, ...)``
            is used instead of ``tf.nn.dynamic_rnn``.
        activation: Activation of the default cell; unused when ``cell`` is given.
        initializer: Not used by this function; kept for interface compatibility.
            NOTE(review): the default ``tf.orthogonal_initializer()`` object is
            created once, when the function is defined.
        initial_state: Forwarded to ``tf.nn.dynamic_rnn`` only; ignored on the
            decoder path.
        keep_prob: When < 1.0, the cell is wrapped in ``DropoutWrapper`` with
            ``_global_keep_prob(keep_prob)`` as both the input and output keep
            probability. Wrapping also applies to a caller-supplied ``cell``.
        return_final_state: If True, return the final state instead of the outputs.
        return_next_cell_input: Not used by this function; kept for interface
            compatibility.
        **opts: Only ``opts["name"]`` is read. If truthy, the (possibly
            dropout-wrapped) cell is added to the graph collection of that name.

    Returns:
        The final RNN state if ``return_final_state`` is True, otherwise the RNN
        outputs.
    """
    if cell is None:
        cell = tf.contrib.rnn.BasicRNNCell(hidden_dims, activation=activation)

    if keep_prob < 1.0:
        # NOTE(review): presumably _global_keep_prob lets dropout be switched off
        # globally (e.g. at inference); confirm against its definition.
        keep_prob = _global_keep_prob(keep_prob)
        cell = tf.contrib.rnn.DropoutWrapper(
            cell, input_keep_prob=keep_prob, output_keep_prob=keep_prob)

    collection_name = opts.get("name")
    if collection_name:
        tf.add_to_collection(collection_name, cell)

    if decoder_fn is None:
        outputs, final_state = tf.nn.dynamic_rnn(
            cell, tensor,
            sequence_length=sequence_length, initial_state=initial_state, dtype=tf.float32)
    else:
        # TODO: turn off sequence_length?
        # The decoder's third return value (final context state) is not needed here.
        outputs, final_state, _ = seq2seq.dynamic_rnn_decoder(
            cell, decoder_fn, inputs=None, sequence_length=sequence_length)

    return final_state if return_final_state else outputs
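# NOTE(review): stray blog-page labels ("Comment list", "Table of contents") from a web scrape stood here; they are not code.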