saved_transform_io_test.py source code

python
Views 29 · Favorites 0 · Likes 0 · Comments 0

Project: transform · Author: tensorflow · Project source · File source
def test_sparse_roundtrip(self):
    """Round-trips a SparseTensor through a saved transform.

    Phase 1 exports a graph whose sparse float32 input is divided by 5.0.
    Phase 2 re-imports that graph with apply_saved_transform, feeding a
    computed sparse input (the original values multiplied by 10). It then
    checks that the output is still a SparseTensor, that indices and
    dense_shape are unchanged, and that the values equal the originals
    scaled by 10 / 5 = 2.
    """
    import shutil  # only needed to clean up the temporary export directory

    export_dir = tempfile.mkdtemp()
    # Delete the temp directory when the test finishes so repeated runs do
    # not leave exported graphs behind on disk.
    self.addCleanup(shutil.rmtree, export_dir, ignore_errors=True)
    export_path = os.path.join(export_dir, 'export')

    # Phase 1: build the transform graph and write it to export_path.
    with tf.Graph().as_default():
      # Entering the Session itself makes it the default session AND closes
      # it on exit; `tf.Session().as_default()` alone would leave it open.
      with tf.Session() as session:
        input_float = tf.sparse_placeholder(tf.float32)
        output = input_float / 5.0
        inputs = {'input': input_float}
        outputs = {'output': output}
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)

    # Phase 2: load the saved transform into a fresh graph and evaluate it.
    with tf.Graph().as_default():
      with tf.Session() as session:
        indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
        values = np.array([1.0, 2.0], dtype=np.float32)
        shape = np.array([7, 9, 2], dtype=np.int64)
        input_sparse = tf.SparseTensor(
            indices=indices, values=values, dense_shape=shape)

        # Using a computed input gives confidence that the graphs are fused
        inputs = {'input': input_sparse * 10}
        outputs = saved_transform_io.apply_saved_transform(export_path, inputs)
        output_sparse = outputs['output']
        self.assertIsInstance(output_sparse, tf.SparseTensor)
        result = session.run(output_sparse)

        # indices and shape unchanged; values multiplied by 2 (x * 10 / 5)
        self.assertEqual(indices.tolist(), result.indices.tolist())
        self.assertEqual([2.0, 4.0], result.values.tolist())
        self.assertEqual(shape.tolist(), result.dense_shape.tolist())
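Comments
Table of contents


Questions


Interview experiences


Articles


WeChat
Official account

Scan the QR code to follow the official account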