def version(self):
    """Return the installed PyTorch version as a ``(MAJOR, MINOR, PATCH)`` tuple.

    Only the leading numeric release components of ``torch.__version__``
    are parsed; anything after them (for example a local build tag such as
    ``+cu118`` or a pre-release marker such as ``a0``) is ignored. A
    two-component release string such as ``"2.1"`` is treated as
    ``"2.1.0"``, consistent with PEP 440's zero-padding of release segments.

    Returns:
        A tuple of three ints. If the version string cannot be parsed at
        all, a warning is logged and ``(0, 0, 0)`` is returned.
    """
    # Imported inside the method rather than at module level (presumably so
    # the module stays importable when PyTorch is absent -- confirm).
    import torch # pylint: disable=import-error

    version_string = torch.__version__
    # MAJOR.MINOR is required and .PATCH is optional. re.match anchors at the
    # start of the string, so any trailing build/pre-release text is ignored.
    match = re.match(r'([0-9]+)\.([0-9]+)(?:\.([0-9]+))?', version_string)
    if not match:
        logger.warning('Unable to infer PyTorch version. We '
            'cannot check for version incompatibilities.')
        return (0, 0, 0)
    # A missing PATCH group comes back as '0' rather than None.
    major, minor, patch = match.groups(default='0')
    return (int(major), int(minor), int(patch))
###########################################################################
评论列表
文章目录