def test_sparse_roundtrip(self):
  """Round-trips a SparseTensor through a saved transform.

  Exports a graph that computes ``input / 5.0`` on a sparse float32
  placeholder. It then applies that saved transform in a fresh graph to a
  computed sparse input (``input_sparse * 10``). Finally it checks that the
  indices and dense shape pass through unchanged while the values are scaled
  by 10 / 5 = 2.
  """
  import shutil  # Local import: only needed to clean up the export dir.

  export_dir = tempfile.mkdtemp()
  # mkdtemp() never deletes the directory itself; remove it after the test.
  self.addCleanup(shutil.rmtree, export_dir, ignore_errors=True)
  export_path = os.path.join(export_dir, 'export')

  # Phase 1: build the transform graph and export it.
  with tf.Graph().as_default():
    # `with tf.Session()` makes the session the default AND closes it on
    # exit; `tf.Session().as_default()` alone leaves the session open.
    with tf.Session() as session:
      input_float = tf.sparse_placeholder(tf.float32)
      output = input_float / 5.0
      inputs = {'input': input_float}
      outputs = {'output': output}
      saved_transform_io.write_saved_transform_from_session(
          session, inputs, outputs, export_path)

  # Phase 2: load the saved transform into a new graph and run it.
  with tf.Graph().as_default():
    with tf.Session() as session:
      indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
      values = np.array([1.0, 2.0], dtype=np.float32)
      shape = np.array([7, 9, 2], dtype=np.int64)
      input_sparse = tf.SparseTensor(
          indices=indices, values=values, dense_shape=shape)
      # Using a computed input gives confidence that the graphs are fused.
      inputs = {'input': input_sparse * 10}
      outputs = saved_transform_io.apply_saved_transform(export_path, inputs)
      output_sparse = outputs['output']
      self.assertIsInstance(output_sparse, tf.SparseTensor)
      result = session.run(output_sparse)
      # Indices and dense shape are unchanged. Values are multiplied by 10
      # (computed input) and then divided by 5 (saved graph), i.e. doubled.
      self.assertEqual(indices.tolist(), result.indices.tolist())
      self.assertEqual([2.0, 4.0], result.values.tolist())
      self.assertEqual(shape.tolist(), result.dense_shape.tolist())
评论列表
文章目录