dynamic_rnn_estimator.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project, author: rashmitripathi
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.util import nest

# `_get_state_name` is assumed to be a helper defined elsewhere in this module.

def dict_to_state_tuple(input_dict, cell):
  """Reconstructs nested `state` from a dict containing state `Tensor`s.

  Args:
    input_dict: A dict of `Tensor`s.
    cell: An instance of `RNNCell`.
  Returns:
    If `input_dict` does not contain keys 'STATE_PREFIX_i' for `0 <= i < n`
    where `n` is the number of nested entries in `cell.state_size`, this
    function returns `None`. Otherwise, returns a `Tensor` if `cell.state_size`
    is an `int` or a nested tuple of `Tensor`s if `cell.state_size` is a nested
    tuple.
  Raises:
    ValueError: State is partially specified. The `input_dict` must contain
      values for all state components or none at all.
  """
  flat_state_sizes = nest.flatten(cell.state_size)
  state_tensors = []
  with ops.name_scope('dict_to_state_tuple'):
    for i, state_size in enumerate(flat_state_sizes):
      state_name = _get_state_name(i)
      state_tensor = input_dict.get(state_name)
      if state_tensor is not None:
        rank_check = check_ops.assert_rank(
            state_tensor, 2, name='check_state_{}_rank'.format(i))
        shape_check = check_ops.assert_equal(
            array_ops.shape(state_tensor)[1],
            state_size,
            name='check_state_{}_shape'.format(i))
        with ops.control_dependencies([rank_check, shape_check]):
          state_tensor = array_ops.identity(state_tensor, name=state_name)
        state_tensors.append(state_tensor)
    if not state_tensors:
      return None
    elif len(state_tensors) == len(flat_state_sizes):
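      # The dummy state is used only as a structural template for packing;
      # its values and dtype are never read.
      dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
      return nest.pack_sequence_as(dummy_state, state_tensors)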
    else:
      raise ValueError(
          'RNN state was partially specified. '
          'Expected zero or {} state Tensors; got {}'.format(
              len(flat_state_sizes), len(state_tensors)))
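
The all-or-nothing lookup and repacking described in the docstring can be sketched without TensorFlow. The version below is only an illustration under stated assumptions: `flatten`, `pack_like`, `dict_to_state` and the `'state'` key prefix are hypothetical stand-ins (not the library's `nest` API or its real key prefix), and plain lists of rows stand in for rank-2 `Tensor`s.

```python
STATE_PREFIX = 'state'  # hypothetical prefix; the real module defines its own


def flatten(structure):
    """Flattens a nested tuple into a list of leaves, depth-first."""
    if isinstance(structure, tuple):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


def pack_like(structure, flat_values):
    """Rebuilds the nesting of `structure` from the flat list `flat_values`."""
    values = iter(flat_values)

    def build(node):
        if isinstance(node, tuple):
            return tuple(build(child) for child in node)
        return next(values)

    return build(structure)


def dict_to_state(input_dict, state_size):
    """Returns None, a packed state, or raises if the state is partial."""
    flat_sizes = flatten(state_size)
    found = []
    for i, size in enumerate(flat_sizes):
        value = input_dict.get('{}_{}'.format(STATE_PREFIX, i))
        if value is not None:
            # Stand-in for the shape assertion: every row must be `size` wide.
            if any(len(row) != size for row in value):
                raise ValueError('state {} has the wrong width'.format(i))
            found.append(value)
    if not found:
        return None
    if len(found) == len(flat_sizes):
        return pack_like(state_size, found)
    raise ValueError(
        'State was partially specified. Expected zero or {} entries; '
        'got {}'.format(len(flat_sizes), len(found)))
```

Called with a nested `state_size` such as `((2, 3), 4)`, an empty dict yields `None`, a dict holding all three keys yields a tuple with the same nesting, and a dict holding only some of the keys raises `ValueError`.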