def zero_state(self, batch_size, dtype):
    """Build the zero-filled initial state tensor(s) for this cell.

    Args:
        batch_size: int, float, or unit Tensor giving the batch size.
        dtype: data type of the returned state.

    Returns:
        When `state_size` is an int or TensorShape, an `N-D` tensor of
        shape `[batch_size x state_size]` filled with zeros.
        When `state_size` is a nested list or tuple, a nested list or
        tuple with the same structure, holding `2-D` tensors of shape
        `[batch_size x s]`, one for each `s` in `state_size`.
    """
    # The name scope is kept for backwards compatibility.
    scope_name = type(self).__name__ + "ZeroState"
    with tf.name_scope(scope_name, values=[batch_size]):
        # `self.state_size` is read inside the scope, as before, and the
        # actual tensor construction is delegated to the shared helper.
        zeros = rnn_cell_impl._zero_state_tensors(  # pylint: disable=protected-access
            self.state_size, batch_size, dtype)
        return zeros
评论列表
文章目录