def initial_state(self, batch_size, dtype=tf.float32, trainable=False,
                  trainable_initializers=None, trainable_regularizers=None,
                  name=None):
    """Builds the default start state tensor of zeros.

    Args:
        batch_size: An int, float or scalar Tensor representing the batch size.
        dtype: The data type to use for the state.
        trainable: Boolean that indicates whether to learn the initial state.
        trainable_initializers: An optional pair of initializers for the
            initial hidden state and cell state.
        trainable_regularizers: Optional regularizer function or nested
            structure of functions with the same structure as the `state_size`
            property of the core, to be used as regularizers of the initial
            state variable. A regularizer should be a function that takes a
            single `Tensor` as an input and returns a scalar `Tensor` output,
            e.g. the L1 and L2 regularizers in `tf.contrib.layers`.
        name: Optional string used to prefix the initial state variable names,
            in the case of a trainable initial state. If not provided, defaults
            to the name of the module.

    Returns:
        A tensor tuple `([batch_size, state_size], [batch_size, state_size], ?)`
        filled with zeros, with the third entry present when batch norm is
        enabled with `max_unique_stats > 1`, with value `0` (representing the
        time step).
    """
    if self._max_unique_stats == 1:
        # No time-step entry is needed in this configuration (see Returns),
        # so the parent class's initial state is used unchanged.
        return super(BatchNormLSTM, self).initial_state(
            batch_size, dtype=dtype, trainable=trainable,
            trainable_initializers=trainable_initializers,
            trainable_regularizers=trainable_regularizers, name=name)

    # The same string serves as the op name scope and as the prefix for any
    # trainable initial-state variables.
    # NOTE(review): hoisting assumes `_initial_state_scope` is a pure function
    # of `name`; confirm if that helper ever changes.
    scope_name = self._initial_state_scope(name)

    with tf.name_scope(scope_name):
        if trainable:
            # We have to manually create the state ourselves so we don't create
            # a variable that never gets used for the third entry. Only the
            # hidden and cell states (each of shape [hidden_size]) are learned.
            state = rnn_core.trainable_initial_state(
                batch_size,
                (tf.TensorShape([self._hidden_size]),
                 tf.TensorShape([self._hidden_size])),
                dtype=dtype,
                initializers=trainable_initializers,
                regularizers=trainable_regularizers,
                name=scope_name)
        else:
            state = self.zero_state(batch_size, dtype)

        # Keep only the (hidden, cell) pair from `state` and append a fresh
        # int32 time-step counter that starts at 0.
        return (state[0], state[1], tf.constant(0, dtype=tf.int32))
评论列表
文章目录