layers.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def __init__(self, num_buckets, num_units_out, initializer=None, name=None,
               trainable=True, mod_inputs=True):
    """Initializes the layer.

    Args:
      num_buckets: How many buckets the embedding has.
      num_units_out: The number of output units in the layer.
      initializer: the initializer for the weights. Defaults to uniform unit
        scaling. The initializer can also be a Tensor or numpy array, in which
        case the weights are initialized to this value and shape. Note that in
        this case the weights will still be trainable unless you also pass
        `trainable=False`. If a Tensor is passed, its static shape is set to
        `(num_buckets, num_units_out)` on that Tensor object itself.
      name: An optional string name. Defaults to
        `Embedding_%d_%d % (num_buckets, num_units_out)`. Used to name the
        variable scope where the variables for the layer live.
      trainable: Whether or not to make the weights trainable.
      mod_inputs: Whether or not to mod the input by the number of buckets.
        Stored coerced to `bool`.

    Raises:
      ValueError: If `initializer` is a Tensor or numpy array whose shape is
        not compatible with `(num_buckets, num_units_out)` (raised by
        `Tensor.set_shape`).
    """

    # Record the constructor arguments under the name 'td.Embedding'.
    # This call must stay BEFORE `name` and `initializer` are rebound below.
    # Presumably get_local_arguments (defined elsewhere) reads this function's
    # current parameter values, so running it first captures exactly what the
    # caller passed rather than the resolved defaults.
    # NOTE(review): confirm against get_local_arguments' definition.
    self.set_constructor_args('td.Embedding',
                              *get_local_arguments(Embedding.__init__, True))

    self._weights_shape = (num_buckets, num_units_out)
    # Default variable-scope name, e.g. 'Embedding_1000_64'.
    if name is None: name = 'Embedding_%d_%d' % self._weights_shape
    if initializer is None:
      initializer = tf.uniform_unit_scaling_initializer(1.0)
    elif isinstance(initializer, np.ndarray):
      # Normalize numpy arrays to a Tensor so the branch below handles both.
      initializer = tf.convert_to_tensor(initializer)
    if isinstance(initializer, tf.Tensor):
      # Validate and pin the static shape. This raises ValueError if the
      # shape is incompatible, and it mutates the passed-in Tensor object.
      initializer.set_shape(self._weights_shape)
      # The Tensor initializer already carries the shape. Clear the explicit
      # shape because, per the original note, get_variable fails when both a
      # shape and a Tensor initializer are given.
      self._weights_shape = None  # otherwise get_variable barfs
    self._initializer = initializer
    self._num_buckets = num_buckets
    self._num_units_out = num_units_out
    self._trainable = trainable
    # Store a real bool regardless of the truthy value the caller passed.
    self._mod_inputs = bool(mod_inputs)
    # Output type is a tensor of shape [num_units_out]. The (possibly
    # defaulted) name is used as the layer's name/scope.
    super(Embedding, self).__init__(
        output_type=tdt.TensorType([num_units_out]), name_or_scope=name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号