layers.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def _create_variables(self):
    if self.input_type.ndim != 0:
      raise TypeError('Embeddings take scalar inputs.')
    dtype = tf.as_dtype(self.input_type.dtype)
    if not dtype.is_integer: raise TypeError('Embeddings take integer inputs.')
    if dtype not in (tf.int32, tf.int64):  # only dtypes supported by tf.gather
      if np.iinfo(dtype.as_numpy_dtype).max > 2147483647:
         # pedantic future-proofing to handle hypothetical tf.uint64
        raise TypeError('cannot gather or upcast dtype %s' % dtype)
      self._cast = True
    else:
      self._cast = False
    self._weights = tf.get_variable(
        'weights', self._weights_shape, initializer=self._initializer,
        trainable=self._trainable)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号