def test_get_graph_def_from_url_tarball(self):
  """Test `get_graph_def_from_url_tarball`.

  Serializes a dummy GraphDef to a temp file, packs it into a local gzip
  tarball, and mocks `urllib` inside `gan_metrics` so the "download"
  returns that tarball instead of touching the network. The GraphDef read
  back must equal the one written. Every temporary file/directory created
  here is removed via `addCleanup`, so nothing is left behind even if an
  assertion fails.
  """
  import shutil  # Local import: only needed for this test's cleanup.

  # Write dummy binary GraphDef to tempfile. `delete=False` keeps the file
  # on disk after the handle is closed so it can be tarred below; that
  # makes the test responsible for deleting it.
  with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
    self.addCleanup(os.remove, tmp_file.name)
    tmp_file.write(_get_dummy_graphdef().SerializeToString())
  # The same path string becomes the tar member name (tarfile's default
  # arcname) and is passed as the second argument to the function under
  # test, presumably as the member to look up inside the archive.
  relative_path = os.path.relpath(tmp_file.name)

  # Create gzip tarball in a scratch directory that is removed afterwards.
  tar_dir = tempfile.mkdtemp()
  self.addCleanup(shutil.rmtree, tar_dir, ignore_errors=True)
  tar_filename = os.path.join(tar_dir, 'tmp.tar.gz')
  with tarfile.open(tar_filename, 'w:gz') as tar:
    tar.add(relative_path)

  # Stub the download: `urllib.request.urlretrieve` returns a
  # (filename, headers) tuple, so hand back the local tarball path.
  with mock.patch.object(gan_metrics, 'urllib') as mock_urllib:
    mock_urllib.request.urlretrieve.return_value = tar_filename, None
    graph_def = gan_metrics.get_graph_def_from_url_tarball(
        'unused_url', relative_path)

  self.assertIsInstance(graph_def, tf.GraphDef)
  self.assertEqual(_get_dummy_graphdef(), graph_def)
评论列表
文章目录