def load_artifact(file_name, cuda=False, device_id=-1):
    """Load a serialized pyprob artifact and place it on the requested device.

    Parameters
    ----------
    file_name : str
        Path to a file previously written with ``torch.save``.
    cuda : bool
        If True, the artifact is moved to a CUDA device after loading.
    device_id : int
        Target CUDA device index; ``-1`` means "use the current device".

    Returns
    -------
    The deserialized artifact, moved to CUDA (``cuda=True``) or CPU.

    Raises
    ------
    RuntimeError
        If the file cannot be deserialized.
    """
    try:
        if cuda:
            artifact = torch.load(file_name)
        else:
            # Remap any CUDA-saved tensors onto CPU storage so a CPU-only
            # machine can still load the file.
            artifact = torch.load(file_name, map_location=lambda storage, loc: storage)
    except Exception as e:
        logger.log_error('load_artifact: Cannot load file')
        # Re-raise instead of falling through: the original code continued and
        # hit a NameError on the never-assigned `artifact`, masking the cause.
        raise RuntimeError('load_artifact: Cannot load file {0}'.format(file_name)) from e

    # Warn (but proceed) on version skew between the artifact and the runtime.
    if artifact.code_version != pyprob.__version__:
        logger.log()
        logger.log_warning('Different pyprob versions (artifact: {0}, current: {1})'.format(artifact.code_version, pyprob.__version__))
        logger.log()
    if artifact.pytorch_version != torch.__version__:
        logger.log()
        logger.log_warning('Different PyTorch versions (artifact: {0}, current: {1})'.format(artifact.pytorch_version, torch.__version__))
        logger.log()

    if cuda:
        if device_id == -1:
            device_id = torch.cuda.current_device()
        if artifact.on_cuda:
            # NOTE(review): nesting was lost in the flattened source; moving
            # only on a device-id mismatch matches the warning's wording —
            # confirm against upstream pyprob if available.
            if device_id != artifact.cuda_device_id:
                logger.log_warning('Loading CUDA (device {0}) artifact to CUDA (device {1})'.format(artifact.cuda_device_id, device_id))
                logger.log()
                artifact.move_to_cuda(device_id)
        else:
            logger.log_warning('Loading CPU artifact to CUDA (device {0})'.format(device_id))
            logger.log()
            artifact.move_to_cuda(device_id)
    else:
        if artifact.on_cuda:
            logger.log_warning('Loading CUDA artifact to CPU')
            logger.log()
            artifact.move_to_cpu()
    return artifact