common.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def assertExpected(self, s, subname=None):
        """
        Test that a string matches the recorded contents of a file
        derived from the name of this test and subname.  This file
        is placed in the 'expect' directory in the same directory
        as the test script. You can automatically update the recorded test
        output using --accept.

        If you call this multiple times in a single function, you must
        give a unique subname each time.

        Args:
            s: the actual output to check. Must be a ``str`` (``unicode`` is
                also accepted on Python 2).
            subname: optional suffix; when truthy, the expect file is named
                ``<test id>-<subname>.expect`` instead of ``<test id>.expect``.

        Returns:
            None. In accept mode (the module-level ``ACCEPT`` flag, defined
            elsewhere in this module — presumably set by ``--accept``), the
            expect file is created or overwritten with ``s`` when it is
            missing or differs.

        Raises:
            TypeError: if ``s`` is not a string.
            RuntimeError: if the expect file does not exist and accept mode
                is off.
            Any error from opening the expect file other than "file not found"
                is re-raised unchanged.
            A test failure (from ``assertMultiLineEqual`` / ``assertEqual``)
                if ``s`` differs from the recorded output and accept mode is off.
        """
        if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))):
            raise TypeError("assertExpected is strings only")

        def remove_prefix(text, prefix):
            # Hand-rolled str.removeprefix (3.9+ only); this file still
            # supports Python 2.
            if text.startswith(prefix):
                return text[len(prefix):]
            return text
        munged_id = remove_prefix(self.id(), "__main__.")
        # NB: we take __file__ from __main__, so we place the expect directory
        # where the test script lives, NOT where test/common.py lives.  This
        # doesn't matter in PyTorch where all test scripts are in the same
        # directory as test/common.py, but it matters in onnx-pytorch
        expected_dir = os.path.join(os.path.dirname(os.path.realpath(__main__.__file__)),
                                    "expect")
        expected_file = os.path.join(expected_dir, munged_id)
        if subname:
            expected_file += "-" + subname
        expected_file += ".expect"
        expected = None

        def accept_output(update_type):
            # Record ``s`` as the new expected output, announcing what changed.
            print("Accepting {} for {}:\n\n{}".format(update_type, munged_id, s))
            # The 'expect' directory may not exist yet (e.g. the first accepted
            # output for a brand-new test script); create it instead of failing
            # on open().  os.makedirs(exist_ok=True) is Python 3 only, so an
            # already-existing directory is tolerated by checking EEXIST.
            try:
                os.makedirs(expected_dir)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
            with open(expected_file, 'w') as f:
                f.write(s)

        try:
            with open(expected_file) as f:
                expected = f.read()
        except IOError as e:
            # Only a missing expect file is handled here; anything else
            # (permissions, etc.) propagates.
            if e.errno != errno.ENOENT:
                raise
            elif ACCEPT:
                return accept_output("output")
            else:
                raise RuntimeError(
                    ("I got this output for {}:\n\n{}\n\n"
                     "No expect file exists; to accept the current output, run:\n"
                     "python {} {} --accept").format(munged_id, s, __main__.__file__, munged_id))
        if ACCEPT:
            # Only rewrite the file when the output actually changed.
            if expected != s:
                return accept_output("updated output")
        else:
            if hasattr(self, "assertMultiLineEqual"):
                # assertMultiLineEqual exists on unittest.TestCase in
                # Python 2.7+ and gives a readable line diff; fall back to
                # assertEqual on older/other TestCase implementations.
                # NB: Python considers lhs "old" and rhs "new".
                self.assertMultiLineEqual(expected, s)
            else:
                self.assertEqual(s, expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号