def get_state_reset_op(state_variables, cell, max_batch_size):
    """Returns an operation that sets each variable in a list of LSTMStateTuples to zero.

    See get_state_variables() for more info.

    Args:
        state_variables (tuple[tf.contrib.rnn.LSTMStateTuple]): The LSTM's state variables.
        cell (tf.contrib.rnn.MultiRNNCell): A MultiRNNCell consisting of multiple LSTMCells.
        max_batch_size (int): The maximum size of batches that will be fed to the LSTMCell.

    Returns:
        tf.Operation: An operation that sets the LSTM's state to zero.
    """
    # Build an all-zeros state for the cell, sized for the largest batch.
    # The dtype is fixed to float32; presumably the state variables created by
    # get_state_variables() are float32 too — NOTE(review): confirm they match.
    zero_states = cell.zero_state(max_batch_size, tf.float32)
    # Delegate to the generic update helper, passing the zero states as the
    # new values for the state variables.
    return get_state_update_op(state_variables, zero_states)
评论列表
文章目录