metric_ops_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testSparseExpandAndTile5x(self):
    """Tile a 3x3 sparse value 5 times along a new axis at each position.

    For each insertion point of the new axis (the default `dim`, 0, 1, 2,
    and the negative aliases -2 and -1, which are expected to match 0 and 1
    respectively), the evaluated result of `metric_ops.expand_and_tile` is
    compared against a hand-built expected `SparseTensorValue` in which every
    stored entry is repeated 5 times.
    """
    multiple = 5
    # Input of shape (3, 3): row 0 has 2 entries, row 1 has 3, row 2 has 1.
    sparse_input = tf.SparseTensorValue(
        indices=((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 0)),
        values=(1, 2, 3, 4, 5, 6),
        shape=(3, 3))

    def check(expected, **tile_kwargs):
      # Tile `sparse_input` (forwarding any explicit `dim`) and compare the
      # evaluated result with `expected`.
      actual = metric_ops.expand_and_tile(
          sparse_input, multiple=multiple, **tile_kwargs).eval()
      self._assert_sparse_tensors_equal(expected, actual)

    with self.test_session():
      # New axis first -> shape (5, 3, 3): the whole input repeated 5 times.
      along_axis0 = tf.SparseTensorValue(
          indices=[(rep, row, col)
                   for rep in range(multiple)
                   for row, col in sparse_input.indices],
          values=list(sparse_input.values) * multiple,
          shape=(multiple, 3, 3))
      check(along_axis0)  # default dim
      for dim in (-2, 0):
        check(along_axis0, dim=dim)

      # New axis in the middle -> shape (3, 5, 3): each input row's entries
      # are repeated 5 times before moving on to the next row.
      along_axis1_values = ()
      for start, stop in ((0, 2), (2, 5), (5, 6)):  # value span of each row
        along_axis1_values += sparse_input.values[start:stop] * multiple
      along_axis1 = tf.SparseTensorValue(
          indices=[(row, rep, col)
                   for row in range(3)
                   for rep in range(multiple)
                   for entry_row, col in sparse_input.indices
                   if entry_row == row],
          values=along_axis1_values,
          shape=(3, multiple, 3))
      for dim in (-1, 1):
        check(along_axis1, dim=dim)

      # New axis last -> shape (3, 3, 5): every entry repeated 5 times in place.
      along_axis2 = tf.SparseTensorValue(
          indices=[(row, col, rep)
                   for row, col in sparse_input.indices
                   for rep in range(multiple)],
          values=[val for val in sparse_input.values for _ in range(multiple)],
          shape=(3, 3, multiple))
      check(along_axis2, dim=2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号