test_gan_metrics.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目: tefla 作者: openAGI 项目源码 文件源码
def test_get_graph_def_from_url_tarball(self):
        """Test `get_graph_def_from_url_tarball`.

        Serializes a dummy GraphDef to a temporary file, packs it into a
        gzip tarball, and patches `gan_metrics.urllib` so the "download"
        returns that local tarball. The GraphDef returned by the function
        must be a `tf.GraphDef` equal to the one that was written.

        All temporary files and directories created here are removed via
        `addCleanup`, so repeated runs do not leak files into the system
        temp directory.
        """
        import shutil  # local import: the file-level import block is not visible here

        # Write dummy binary GraphDef to tempfile. `delete=False` keeps the
        # file on disk after the `with` closes it (tarfile must read it
        # later), so cleanup has to be registered explicitly. Registering it
        # before writing ensures removal even if the write raises.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            self.addCleanup(os.remove, tmp_file.name)
            tmp_file.write(_get_dummy_graphdef().SerializeToString())
        # The tarball member is stored under this cwd-relative path, and the
        # same string is passed to the function under test (presumably as the
        # member name to extract -- confirm against gan_metrics).
        relative_path = os.path.relpath(tmp_file.name)

        # Create gzip tarball in a fresh temp dir, removed after the test.
        tar_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tar_dir, ignore_errors=True)
        tar_filename = os.path.join(tar_dir, 'tmp.tar.gz')
        with tarfile.open(tar_filename, 'w:gz') as tar:
            tar.add(relative_path)

        # Replace the module's `urllib` so no network access happens; the
        # mocked `urlretrieve` returns a (local_path, headers) tuple pointing
        # at the tarball built above.
        with mock.patch.object(gan_metrics, 'urllib') as mock_urllib:
            mock_urllib.request.urlretrieve.return_value = tar_filename, None
            graph_def = gan_metrics.get_graph_def_from_url_tarball(
                'unused_url', relative_path)

        self.assertIsInstance(graph_def, tf.GraphDef)
        self.assertEqual(_get_dummy_graphdef(), graph_def)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号