def testSparseExpandAndTile1x(self):
    """Exercises `metric_ops.expand_and_tile` with `multiple=1` on sparse input.

    The input is a 3x3 `SparseTensorValue`. For each requested `dim`, the test
    expects the result to keep the original values unchanged, with a single
    size-1 axis spliced in. Every index gains a 0 coordinate at that axis, and
    the dense shape gains a 1 there.

    Negative `dim` values are checked against the same expectation as their
    positive counterparts: -2 and 0 map to axis 0, and -1 and 1 map to axis 1.
    """
    dense_shape = [3, 3]
    x = tf.SparseTensorValue(
        indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 0]],
        values=[1, 2, 3, 4, 5, 6],
        shape=dense_shape)

    def _inserted_at(axis):
        # Expected output: a size-1 axis inserted at `axis`; values untouched.
        new_indices = [
            list(idx[:axis]) + [0] + list(idx[axis:]) for idx in x.indices]
        new_shape = dense_shape[:axis] + [1] + dense_shape[axis:]
        return tf.SparseTensorValue(
            indices=new_indices, values=x.values, shape=new_shape)

    # Each entry: (axis the new dimension should appear at,
    #              `dim` arguments expected to produce that result).
    cases = (
        (0, (-2, 0)),
        (1, (-1, 1)),
        (2, (2,)),
    )

    with self.test_session():
        # The call without an explicit `dim` is expected to insert at axis 0.
        self._assert_sparse_tensors_equal(
            _inserted_at(0), metric_ops.expand_and_tile(x, multiple=1).eval())
        for axis, dims in cases:
            expected = _inserted_at(axis)
            for dim in dims:
                actual = metric_ops.expand_and_tile(
                    x, multiple=1, dim=dim).eval()
                self._assert_sparse_tensors_equal(expected, actual)
# TODO(ptucker): Use @parameterized when it's available in tf.
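# 评论列表 ("Comment list": scraped web-page residue, not code)
# 文章目录 ("Table of contents": scraped web-page residue, not code)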