def get_state_update_op(state_variables, new_states):
    """Build the ops that write freshly computed LSTM states into their variables.

    Each entry of ``state_variables`` is paired positionally with the entry of
    ``new_states`` at the same index. For both elements (index 0 and index 1)
    of every pair, the leading rows of the variable are overwritten with the
    new value via ``tf.scatter_update``.

    Args:
        state_variables (tuple[tf.contrib.rnn.LSTMStateTuple]): The variables
            holding the LSTM's state, one state tuple per entry.
        new_states (tuple[tf.contrib.rnn.LSTMStateTuple]): The values to store.
            A new state may have fewer rows than its variable (the variable is
            sized for max_batch_size); in that case only the first
            ``tf.shape(new_state)[0]`` rows of the variable are written and
            the remaining rows are left as they are.

    Returns:
        The result of ``tf.tuple`` applied to all scatter-update ops, ordered
        per pair as (element 0, element 1).
        NOTE(review): the previous docstring called this a ``tf.Operation``;
        ``tf.tuple`` normally yields a list of tensors -- confirm how callers
        run it.
    """

    def _leading_row_indices(tensor):
        # Row indices 0..n-1, where n is the runtime size of the first axis.
        return tf.range(0, tf.shape(tensor)[0])

    scatter_ops = []
    for variable_pair, value_pair in zip(state_variables, new_states):
        # The new values may cover fewer rows than the variables, so address
        # only as many rows as each new value actually provides. Indices are
        # derived separately for each of the two elements of the state tuple.
        rows_first = _leading_row_indices(value_pair[0])
        rows_second = _leading_row_indices(value_pair[1])

        scatter_ops.append(
            tf.scatter_update(variable_pair[0], rows_first, value_pair[0]))
        scatter_ops.append(
            tf.scatter_update(variable_pair[1], rows_second, value_pair[1]))

    # Group all updates so they can be fetched/run together.
    return tf.tuple(scatter_ops)
评论列表
文章目录