rnn_core.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tf-sparql 作者: derdav3 项目源码 文件源码
def trainable_initial_state(batch_size, state_size, dtype, initializers=None):
  """Creates an initial state consisting of trainable variables.

  The trainable variables are created with the same shapes as the elements of
  `state_size` (plus a leading dimension of size 1) and are tiled along that
  leading dimension `batch_size` times to produce an initial state.

  Args:
    batch_size: An int, or scalar int32 Tensor representing the batch size.
    state_size: A `TensorShape` or nested tuple of `TensorShape`s to use for the
        shape of the trainable variables.
    dtype: The data type used to create the variables and thus initial state.
    initializers: An optional container of the same structure as `state_size`
        containing initializers for the variables. When falsy (e.g. `None`),
        every variable uses `tf.zeros_initializer`.

  Returns:
    A `Tensor` or nested tuple of `Tensor`s with the same size and structure
    as `state_size`, where each `Tensor` is a tiled trainable `Variable`.

  Raises:
    ValueError: if the user passes initializers that are not functions.
  """
  flat_state_size = nest.flatten(state_size)
  num_states = len(flat_state_size)

  if not initializers:
    # NOTE(review): the initializer is passed uncalled, as in the original
    # code; presumably `tf.get_variable` accepts this form in the TF version
    # this project targets -- confirm before upgrading TensorFlow.
    flat_initializer = tuple(tf.zeros_initializer for _ in flat_state_size)
  else:
    nest.assert_same_structure(initializers, state_size)
    flat_initializer = nest.flatten(initializers)
    if not all(callable(init) for init in flat_initializer):
      raise ValueError("Not all the passed initializers are callable objects.")

  # Produce names for the variables. In the case of a tuple or nested tuple,
  # this is just a sequence of numbers, but for a flat `namedtuple`, we use
  # the field names. NOTE: this could be extended to nested `namedtuple`s,
  # but for now that's extra complexity that's not used anywhere.
  # `range` (not the Python-2-only `xrange`) keeps this working on Python 3.
  try:
    names = ["init_{}".format(state_size._fields[i])
             for i in range(num_states)]
  except (AttributeError, IndexError):
    # AttributeError: `state_size` is not a namedtuple.
    # IndexError: it flattens to more elements than it has fields.
    names = ["init_state_{}".format(i) for i in range(num_states)]

  flat_initial_state = []

  for name, size, init in zip(names, flat_state_size, flat_initializer):
    # A leading dimension of 1 gives each variable a batch axis that can be
    # tiled below, so one set of trainable values is shared by the batch.
    shape_with_batch_dim = [1] + tensor_shape.as_shape(size).as_list()
    initial_state_variable = tf.get_variable(
        name, shape=shape_with_batch_dim, dtype=dtype, initializer=init)

    # Repeat `batch_size` times along axis 0 and leave every other axis as-is.
    initial_state_variable_dims = initial_state_variable.get_shape().ndims
    tile_dims = [batch_size] + [1] * (initial_state_variable_dims - 1)
    flat_initial_state.append(
        tf.tile(initial_state_variable, tile_dims, name=(name + "_tiled")))

  return nest.pack_sequence_as(structure=state_size,
                               flat_sequence=flat_initial_state)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号