def trainable_initial_state(batch_size, state_size, dtype, initializers=None):
    """Creates an initial state consisting of tiled trainable variables.

    One trainable variable is created per element of `state_size`. Each
    variable has a leading dimension of 1 followed by the shape of the
    corresponding element, and is tiled `batch_size` times along that leading
    dimension to produce the initial state.

    Args:
        batch_size: An int, or scalar int32 Tensor representing the batch
            size.
        state_size: A `TensorShape` or nested tuple of `TensorShape`s to use
            for the shape of the trainable variables.
        dtype: The data type used to create the variables and thus the
            initial state.
        initializers: An optional container of the same structure as
            `state_size` containing initializers for the variables. When it
            is falsy (e.g. `None`), `tf.zeros_initializer` is used for every
            variable.

    Returns:
        A `Tensor` or nested tuple of `Tensor`s with the same structure as
        `state_size`, where each `Tensor` is a tiled trainable `Variable`.

    Raises:
        ValueError: If any of the passed initializers is not callable.
    """
    flat_state_size = nest.flatten(state_size)
    num_states = len(flat_state_size)

    if not initializers:
        # NOTE(review): the initializer is passed uncalled, exactly as before;
        # confirm the TensorFlow version in use accepts it in this form.
        flat_initializer = tuple(tf.zeros_initializer for _ in flat_state_size)
    else:
        # The initializers must mirror `state_size` so that flattening both
        # yields matching (shape, initializer) pairs.
        nest.assert_same_structure(initializers, state_size)
        flat_initializer = nest.flatten(initializers)
        if not all(callable(init) for init in flat_initializer):
            raise ValueError(
                "Not all the passed initializers are callable objects.")

    # Produce names for the variables. In the case of a tuple or nested tuple,
    # this is just a sequence of numbers, but for a flat `namedtuple`, we use
    # the field names. NOTE: this could be extended to nested `namedtuple`s,
    # but for now that's extra complexity that's not used anywhere.
    try:
        # AttributeError: `state_size` has no `_fields` (not a namedtuple).
        # IndexError: more flattened elements than fields (nested structure).
        names = ["init_{}".format(state_size._fields[i])
                 for i in range(num_states)]
    except (AttributeError, IndexError):
        names = ["init_state_{}".format(i) for i in range(num_states)]

    flat_initial_state = []
    for name, size, init in zip(names, flat_state_size, flat_initializer):
        # Prepend a batch dimension of size 1 that is tiled below.
        shape_with_batch_dim = [1] + tensor_shape.as_shape(size).as_list()
        initial_state_variable = tf.get_variable(
            name, shape=shape_with_batch_dim, dtype=dtype, initializer=init)

        # Repeat only along the leading (batch) dimension; every other
        # dimension keeps a multiple of 1.
        initial_state_variable_dims = initial_state_variable.get_shape().ndims
        tile_dims = [batch_size] + [1] * (initial_state_variable_dims - 1)
        flat_initial_state.append(
            tf.tile(initial_state_variable, tile_dims,
                    name=(name + "_tiled")))

    return nest.pack_sequence_as(structure=state_size,
                                 flat_sequence=flat_initial_state)
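# Comment list        (NOTE(review): appears to be web-page residue, not code)
# Table of contents   (NOTE(review): appears to be web-page residue, not code)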