gated_rnn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def initial_state(self, batch_size, dtype=tf.float32, trainable=False,
                    trainable_initializers=None, trainable_regularizers=None,
                    name=None):
    """Constructs the initial state for this core.

    When `max_unique_stats` is 1 the call is forwarded unchanged to the parent
    class. Otherwise a three-element state is assembled here, whose last entry
    is an int32 constant `0` (representing the time step).

    Args:
      batch_size: An int, float or scalar Tensor representing the batch size.
      dtype: The data type to use for the state.
      trainable: Boolean that indicates whether to learn the initial state.
      trainable_initializers: An optional pair of initializers for the
          initial hidden state and cell state. Only consulted when `trainable`
          is true.
      trainable_regularizers: Optional regularizer function or nested structure
          of functions with the same structure as the `state_size` property of
          the core, to be used as regularizers of the initial state variable.
          A regularizer should be a function that takes a single `Tensor` as
          an input and returns a scalar `Tensor` output, e.g. the L1 and L2
          regularizers in `tf.contrib.layers`.
      name: Optional string used to prefix the initial state variable names, in
          the case of a trainable initial state. If not provided, defaults to
          the name of the module.

    Returns:
      If `max_unique_stats` is 1, whatever the parent class's `initial_state`
      returns. Otherwise a tuple `(hidden, cell, step)` where `hidden` and
      `cell` are the first two entries of the freshly built state and `step`
      is a scalar int32 constant `0`.
    """
    if self._max_unique_stats == 1:
      # No time-step entry is needed, so the parent implementation suffices.
      return super(BatchNormLSTM, self).initial_state(
          batch_size,
          dtype=dtype,
          trainable=trainable,
          trainable_initializers=trainable_initializers,
          trainable_regularizers=trainable_regularizers,
          name=name)

    with tf.name_scope(self._initial_state_scope(name)):
      if trainable:
        # Build only the hidden and cell entries by hand; going through the
        # generic path would also create a variable for the third (time step)
        # entry, and that variable would never be used.
        learned_shapes = (tf.TensorShape([self._hidden_size]),
                          tf.TensorShape([self._hidden_size]))
        state = rnn_core.trainable_initial_state(
            batch_size,
            learned_shapes,
            dtype=dtype,
            initializers=trainable_initializers,
            regularizers=trainable_regularizers,
            name=self._initial_state_scope(name))
      else:
        state = self.zero_state(batch_size, dtype)

      # The step counter is always a fresh constant, created inside the scope.
      hidden, cell = state[0], state[1]
      step = tf.constant(0, dtype=tf.int32)
      return (hidden, cell, step)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号