embedding_ops_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
      self):
    with self.test_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_weights[1] = embedding_weights[1].astype(np.float64)
      self.assertRaises(ValueError,
                        tf.contrib.layers.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids)
      embedding_weights = [
          tf.constant(w, dtype=tf.float64) for w in embedding_weights
      ]
      self.assertRaises(ValueError,
                        tf.contrib.layers.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids, sparse_weights)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号