def _create_variables(self):
    """Validate the input type and create the embedding weights variable.

    Side effects:
      * ``self._cast`` is set to ``True`` when the input dtype is an integer
        type other than int32/int64 (the only index dtypes ``tf.gather``
        supports), and to ``False`` otherwise. Presumably consumed elsewhere
        to cast inputs before lookup -- not visible here.
      * ``self._weights`` is created via ``tf.get_variable`` named
        ``'weights'`` with ``self._weights_shape``, ``self._initializer`` and
        ``self._trainable``.

    Raises:
      TypeError: if the input type is not scalar (``ndim != 0``), is not an
        integer dtype, or is an integer dtype other than int32/int64 whose
        maximum representable value exceeds the int32 maximum (2**31 - 1).
    """
    if self.input_type.ndim != 0:
        raise TypeError('Embeddings take scalar inputs.')

    in_dtype = tf.as_dtype(self.input_type.dtype)
    if not in_dtype.is_integer:
        raise TypeError('Embeddings take integer inputs.')

    # tf.gather only accepts int32/int64 indices; any other integer dtype
    # needs a cast before it can be used for lookup.
    needs_cast = in_dtype not in (tf.int32, tf.int64)

    # Refuse dtypes whose range exceeds int32 (e.g. a uint64-like type):
    # pedantic future-proofing carried over from the original code.
    # np.iinfo is only consulted when a cast is actually needed.
    if needs_cast and np.iinfo(in_dtype.as_numpy_dtype).max > np.iinfo(np.int32).max:
        raise TypeError('cannot gather or upcast dtype %s' % in_dtype)

    self._cast = needs_cast

    self._weights = tf.get_variable(
        'weights',
        self._weights_shape,
        initializer=self._initializer,
        trainable=self._trainable,
    )
评论列表
文章目录