sequence_queueing_state_saver.py 文件源码

python
阅读 15 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def save_state(self, state_name, value, name=None):
  """Returns an op to save the current batch of state `state_name`.

  Only the batch entries whose sequence is not yet done are inserted back
  into the state saver's barrier; entries whose sequence is done are
  filtered out and not written.

  Args:
    state_name: string, matches a key provided in `initial_states`.
    value: A `Tensor`.
      Its type must match that of `initial_states[state_name].dtype`.
      If we had at input:

      ```python
      initial_states[state_name].get_shape() == [d1, d2, ...]
      ```

      then the shape of `value` must match:

      ```python
      tf.shape(value) == [batch_size, d1, d2, ...]
      ```

    name: string (optional).  The name scope for newly created ops.

  Returns:
    A control flow op that stores the new state of each entry into
    the state saver.  This op must be run for every iteration that
    accesses data from the state saver (otherwise the state saver
    will never progress through its states and run out of capacity).

  Raises:
    KeyError: if `state_name` does not match any of the initial states
      declared in `initial_states`.
  """
  # Membership on the mapping itself; going through `.keys()` adds nothing.
  if state_name not in self._state_saver._received_states:
    # Format with a 1-tuple so a tuple-valued `state_name` still produces the
    # documented KeyError instead of a %-formatting TypeError.
    raise KeyError("state was not declared: %s" % (state_name,))
  default_name = "InputQueueingStateSaver_SaveState"
  with ops.name_scope(name, default_name, values=[value]):
    # Place all operations on the CPU. Barriers and queues are only
    # implemented for CPU, but all the other book-keeping operations
    # (reshape, shape, range, ...) would be placed on GPUs if available,
    # unless we explicitly tie them to CPU (here: by colocating with the
    # capacity queue).
    with ops.colocate_with(self._state_saver._capacity_queue.queue_ref):
      # Positions of batch entries whose sequence is NOT done. `where` returns
      # a matrix of coordinates; flatten it to a plain index vector (assumes
      # `_sequence_is_done` is a 1-D per-entry flag vector -- TODO confirm).
      indices_where_not_done = array_ops.reshape(
          array_ops.where(
              math_ops.logical_not(self._state_saver._sequence_is_done)),
          [-1])
      # Keys under which the surviving entries will be re-inserted.
      keeping_next_key = array_ops.gather(
          self._state_saver._received_next_key, indices_where_not_done)
      # Validate `value` against the shape of the state tensor received for
      # this batch (`_check_shape` is a helper defined elsewhere in the
      # module).
      value = _check_shape(
          array_ops.identity(value, name="convert_%s" % (state_name,)),
          array_ops.shape(self._state_saver._received_states[state_name]))
      # Keep only the state rows for entries that are still in progress.
      keeping_state = array_ops.gather(value, indices_where_not_done)
      return self._state_saver._barrier.insert_many(
          self._state_saver._get_barrier_index("state", state_name),
          keeping_next_key, keeping_state,
          name="BarrierInsertState_%s" % (state_name,))

# pylint: enable=protected-access

```

评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号