# Assumed context: this test method belongs to a tf.test.TestCase subclass in the
# tensorflow_transform (tf.Transform) test suite, running against TF 1.x. The
# imports below are what the method body needs.
import os
import tempfile

import tensorflow as tf
from tensorflow_transform.saved import saved_transform_io

from tensorflow.contrib import lookup
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test
from tensorflow.python.util import compat


def test_stale_asset_collections_are_cleaned(self):
  # Write a small vocabulary file that will become a SavedModel asset.
  vocabulary_file = os.path.join(
      compat.as_bytes(test.get_temp_dir()), compat.as_bytes('asset'))
  file_io.write_string_to_file(vocabulary_file, 'foo bar baz')

  export_path = os.path.join(tempfile.mkdtemp(), 'export')

  # Create a SavedModel that includes assets.
  with tf.Graph().as_default():
    with tf.Session().as_default() as session:
      input_string = tf.placeholder(tf.string)
      # Map the string through a table loaded from the asset file.
      table = lookup.index_table_from_file(
          vocabulary_file, num_oov_buckets=12, default_value=12)
      output = table.lookup(input_string)
      inputs = {'input': input_string}
      outputs = {'output': output}
      saved_transform_io.write_saved_transform_from_session(
          session, inputs, outputs, export_path)

  # Load it and save it again repeatedly, verifying that the asset
  # collections remain valid.
  for _ in [1, 2, 3]:
    with tf.Graph().as_default() as g:
      with tf.Session().as_default() as session:
        input_string = tf.constant('dog')
        inputs = {'input': input_string}
        outputs = saved_transform_io.apply_saved_transform(
            export_path, inputs)

        self.assertEqual(
            1, len(g.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
        self.assertEqual(
            0, len(g.get_collection(tf.saved_model.constants.ASSETS_KEY)))

        # Check that every ASSET_FILEPATHS entry refers to a Tensor in the
        # graph; if not, get_tensor_by_name() raises KeyError.
        for asset_path in g.get_collection(ops.GraphKeys.ASSET_FILEPATHS):
          tensor_name = asset_path.name
          g.get_tensor_by_name(tensor_name)

        # Re-export into a fresh directory and loop again.
        export_path = os.path.join(tempfile.mkdtemp(), 'export')
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)
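
For completeness, a minimal sketch of how the method above could be wrapped so that the standard TensorFlow test runner picks it up; the class name here is hypothetical, and the method would be pasted in at one level of indentation:

class SavedTransformIOTest(tf.test.TestCase):

  # The test method above would be defined here, e.g.:
  #   def test_stale_asset_collections_are_cleaned(self):
  #     ...
  pass


if __name__ == '__main__':
  tf.test.main()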