def dict_to_state_tuple(input_dict, cell):
    """Reconstructs nested `state` from a dict containing state `Tensor`s.

    For each flattened entry `i` of `cell.state_size`, the state tensor is
    looked up in `input_dict` under the key `_get_state_name(i)`. Every tensor
    that is found is passed through an identity op (named after its key) that
    has two runtime assertions as control dependencies:
    - the tensor has rank 2;
    - its dimension 1 equals the corresponding flattened `state_size` entry.

    Args:
        input_dict: A dict of `Tensor`s.
        cell: An instance of `RNNCell`.

    Returns:
        `None` if `input_dict` contains none of the keys `_get_state_name(i)`
        for `0 <= i < n`, where `n` is the number of flattened entries in
        `cell.state_size`. Otherwise, the state tensors packed into the same
        nested structure as `cell.zero_state(...)`. This is a single `Tensor`
        if `cell.state_size` is an `int`, or a nested tuple of `Tensor`s if
        `cell.state_size` is a nested tuple.

    Raises:
        ValueError: State is partially specified. The `input_dict` must contain
            values for all state components or none at all.
    """
    flat_state_sizes = nest.flatten(cell.state_size)
    state_tensors = []
    with ops.name_scope('dict_to_state_tuple'):
        for i, state_size in enumerate(flat_state_sizes):
            state_name = _get_state_name(i)
            state_tensor = input_dict.get(state_name)
            if state_tensor is None:
                # Absent components are only counted here; the all/none/partial
                # decision is made after the loop.
                continue
            rank_check = check_ops.assert_rank(
                state_tensor, 2, name='check_state_{}_rank'.format(i))
            shape_check = check_ops.assert_equal(
                array_ops.shape(state_tensor)[1],
                state_size,
                name='check_state_{}_shape'.format(i))
            # Route the tensor through an identity op so both checks execute
            # whenever the returned state tensor is evaluated.
            with ops.control_dependencies([rank_check, shape_check]):
                state_tensor = array_ops.identity(state_tensor, name=state_name)
            state_tensors.append(state_tensor)

        if not state_tensors:
            return None
        if len(state_tensors) == len(flat_state_sizes):
            # zero_state is built only as a structural template: pack_sequence_as
            # uses its nesting, not its values or dtype.
            dummy_state = cell.zero_state(batch_size=1, dtype=dtypes.bool)
            return nest.pack_sequence_as(dummy_state, state_tensors)
        # The space after the first sentence was previously missing from the
        # implicitly concatenated literals ("specified.Expected").
        raise ValueError(
            'RNN state was partially specified. '
            'Expected zero or {} state Tensors; got {}'.format(
                len(flat_state_sizes), len(state_tensors)))
dynamic_rnn_estimator.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录