def test_distributive_property(self):
    """Checks that a combined sparse lookup equals the sum of its parts.

    The values "a", "b" and "c" are each looked up from their own
    single-entry sparse tensor, and then all together from one sparse
    tensor. The combined result must match the `add_n` sum of the three
    individual results.
    """
    with self.test_session():
        params = constant_op.constant([.1, .2, .3])

        # One sparse input per value: "a" sits in row 0, "b" and "c" in row 2.
        sp_values_a = sparse_tensor_lib.SparseTensor(
            values=["a"], indices=[[0, 0]], dense_shape=[3, 1])
        sp_values_b = sparse_tensor_lib.SparseTensor(
            values=["b"], indices=[[2, 0]], dense_shape=[3, 1])
        sp_values_c = sparse_tensor_lib.SparseTensor(
            values=["c"], indices=[[2, 0]], dense_shape=[3, 1])

        # The same three values in a single input; "b" and "c" share row 2.
        # NOTE(review): presumably this exercises how multiple values in one
        # row are combined (sum) -- confirm against the lookup implementation.
        sp_values = sparse_tensor_lib.SparseTensor(
            values=["a", "b", "c"],
            indices=[[0, 0], [2, 0], [2, 1]],
            dense_shape=[3, 2])

        # Every lookup uses the same params, dimension and hash key so the
        # individual and combined results are directly comparable.
        result_a = embedding_ops._sampled_scattered_embedding_lookup_sparse(
            params, sp_values_a, dimension=4, hash_key=self._hash_key)
        result_b = embedding_ops._sampled_scattered_embedding_lookup_sparse(
            params, sp_values_b, dimension=4, hash_key=self._hash_key)
        result_c = embedding_ops._sampled_scattered_embedding_lookup_sparse(
            params, sp_values_c, dimension=4, hash_key=self._hash_key)
        result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
            params, sp_values, dimension=4, hash_key=self._hash_key)

        result_abc = math_ops.add_n([result_a, result_b, result_c])

        self.assertAllClose(result.eval(), result_abc.eval())
embedding_ops_test.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录