def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
        self):
    """Check that safe_embedding_lookup_sparse rejects mismatched dtypes.

    Two scenarios are exercised against embedding weights split into 3
    shards by ``self._random_weights``:

    1. Shard 1 is cast to float64 while the other shards keep whatever
       dtype ``_random_weights`` produced (presumably float32 -- confirm),
       and the lookup is called with ids only (no sparse weights).
    2. Every shard is converted to a float64 ``tf.constant`` and the lookup
       is called with the ids plus the sparse weights returned by
       ``_ids_and_weights_3d`` (presumably not float64 -- confirm).

    Both lookups are expected to raise ``ValueError``.
    """
    with self.test_session():
        shards = self._random_weights(num_shards=3)
        ids, weights = self._ids_and_weights_3d()

        # Case 1: one shard's dtype disagrees with its sibling shards.
        shards[1] = shards[1].astype(np.float64)
        lookup = tf.contrib.layers.safe_embedding_lookup_sparse
        with self.assertRaises(ValueError):
            lookup(shards, ids)

        # Case 2: the shards now agree with each other (all float64), but
        # presumably not with the dtype of the sparse weights.
        float64_shards = [tf.constant(s, dtype=tf.float64) for s in shards]
        with self.assertRaises(ValueError):
            lookup(float64_shards, ids, weights)
# [scraped web-page residue, not code] Comment list
# [scraped web-page residue, not code] Table of contents