saved_transform_io_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:transform 作者: tensorflow 项目源码 文件源码
def test_stale_asset_collections_are_cleaned(self):
    """Repeated load/re-export cycles must keep the asset collections valid.

    Writes a SavedModel whose output comes from a lookup table initialized
    from an asset file. Then, three times in a row, it:
      * loads the most recent export into a fresh graph;
      * checks that the graph has exactly one ASSET_FILEPATHS entry and no
        ASSETS_KEY entries, and that every ASSET_FILEPATHS entry names a
        tensor that exists in the graph;
      * writes a new export from that graph for the next iteration to load.
    """
    vocabulary_file = os.path.join(
        compat.as_bytes(test.get_temp_dir()), compat.as_bytes('asset'))
    file_io.write_string_to_file(vocabulary_file, 'foo bar baz')

    export_path = os.path.join(tempfile.mkdtemp(), 'export')

    # Create a SavedModel that includes an asset. Entering the Session itself
    # (rather than only `session.as_default()`) makes it the default session
    # and also closes it on exit, so it does not leak.
    with tf.Graph().as_default():
      with tf.Session() as session:
        input_string = tf.placeholder(tf.string)
        # Map strings through a table loaded from the asset file.
        table = lookup.index_table_from_file(
            vocabulary_file, num_oov_buckets=12, default_value=12)
        output = table.lookup(input_string)
        inputs = {'input': input_string}
        outputs = {'output': output}
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)

    # Load it and save it again repeatedly, verifying that the asset
    # collections remain valid. Each iteration reloads the export written by
    # the previous iteration (export_path is rebound at the end of the loop).
    for _ in range(3):
      with tf.Graph().as_default() as g:
        with tf.Session() as session:
          input_string = tf.constant('dog')
          inputs = {'input': input_string}
          outputs = saved_transform_io.apply_saved_transform(export_path,
                                                             inputs)

          asset_filepaths = g.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
          self.assertEqual(1, len(asset_filepaths))
          self.assertEqual(
              0, len(g.get_collection(tf.saved_model.constants.ASSETS_KEY)))

          # Every ASSET_FILEPATHS entry must refer to a Tensor in this graph;
          # get_tensor_by_name() raises KeyError otherwise.
          for asset_path in asset_filepaths:
            g.get_tensor_by_name(asset_path.name)

          export_path = os.path.join(tempfile.mkdtemp(), 'export')
          saved_transform_io.write_saved_transform_from_session(
              session, inputs, outputs, export_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号