def save_state(self, state_name, value, name=None):
    """Returns an op to save the current batch of state `state_name`.

    Only the batch entries whose sequence is not yet done are written back;
    entries of finished sequences are dropped.

    Args:
      state_name: string, matches a key provided in `initial_states`.
      value: A `Tensor`.
        Its type must match that of `initial_states[state_name].dtype`.
        If we had at input:

        ```python
        initial_states[state_name].get_shape() == [d1, d2, ...]
        ```

        then the shape of `value` must match:

        ```python
        tf.shape(value) == [batch_size, d1, d2, ...]
        ```

      name: string (optional). The name scope for newly created ops.

    Returns:
      A control flow op that stores the new state of each entry into
      the state saver. This op must be run for every iteration that
      accesses data from the state saver (otherwise the state saver
      will never progress through its states and run out of capacity).

    Raises:
      KeyError: if `state_name` does not match any of the initial states
        declared in `initial_states`.
    """
    saver = self._state_saver
    if state_name not in saver._received_states:
        raise KeyError("state was not declared: %s" % state_name)
    default_name = "InputQueueingStateSaver_SaveState"
    with ops.name_scope(name, default_name, values=[value]):
        # Barriers and queues are only implemented for CPU. Colocating with
        # the capacity queue ties the remaining book-keeping ops (reshape,
        # shape, where, gather, ...) to the same device; otherwise they
        # would be placed on a GPU if one were available.
        with ops.colocate_with(saver._capacity_queue.queue_ref):
            # Flat vector of batch positions whose sequence is not done;
            # only these entries have their next key and state written back.
            indices_where_not_done = array_ops.reshape(
                array_ops.where(
                    math_ops.logical_not(saver._sequence_is_done)),
                [-1])
            keeping_next_key = array_ops.gather(
                saver._received_next_key, indices_where_not_done)
            # Check that `value` has the same full shape as the state tensor
            # stored under this name in `_received_states`.
            value = _check_shape(
                array_ops.identity(value, name="convert_%s" % state_name),
                array_ops.shape(saver._received_states[state_name]))
            keeping_state = array_ops.gather(value, indices_where_not_done)
            return saver._barrier.insert_many(
                saver._get_barrier_index("state", state_name),
                keeping_next_key,
                keeping_state,
                name="BarrierInsertState_%s" % state_name)
# pylint: enable=protected-access
```