util.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目：pyprob　作者：probprog　（项目源码 | 文件源码）
def load_artifact(file_name, cuda=False, device_id=-1):
    """Load a saved artifact from disk and place it on the requested device.

    The file is read with ``torch.load``. After loading, the artifact's
    recorded pyprob and PyTorch versions are compared with the running ones,
    and a warning is logged for each mismatch. The artifact is then moved to
    the requested device if it is not already there.

    Args:
        file_name: Path of the artifact file to load.
        cuda: If True, place the artifact on a CUDA device; otherwise place
            it on the CPU.
        device_id: CUDA device index to use when ``cuda`` is True. The
            default ``-1`` means "use ``torch.cuda.current_device()``".
            Ignored when ``cuda`` is False.

    Returns:
        The loaded artifact object, placed on the requested device.

    Raises:
        Any exception raised by ``torch.load`` (e.g. a missing or corrupt
        file) is logged and then re-raised unchanged.
    """
    # NOTE(review): torch.load unpickles the file and can therefore execute
    # arbitrary code -- only load artifacts from trusted sources.
    #
    # On the CPU path, the map_location callable returns each storage as it
    # was deserialized, which keeps it on the CPU instead of restoring it to
    # the GPU it was saved from. With map_location=None (torch.load's
    # default) storages are restored to the location recorded in the file.
    map_location = None if cuda else (lambda storage, loc: storage)
    try:
        artifact = torch.load(file_name, map_location=map_location)
    except Exception:
        # There is no artifact to continue with, so propagate the real cause
        # rather than failing later on an unbound `artifact`.
        logger.log_error('load_artifact: Cannot load file')
        raise

    # Version mismatches are not fatal; they are only reported.
    if artifact.code_version != pyprob.__version__:
        logger.log()
        logger.log_warning('Different pyprob versions (artifact: {0}, current: {1})'.format(artifact.code_version, pyprob.__version__))
        logger.log()
    if artifact.pytorch_version != torch.__version__:
        logger.log()
        logger.log_warning('Different PyTorch versions (artifact: {0}, current: {1})'.format(artifact.pytorch_version, torch.__version__))
        logger.log()

    if cuda:
        if device_id == -1:
            device_id = torch.cuda.current_device()
        if not artifact.on_cuda:
            logger.log_warning('Loading CPU artifact to CUDA (device {0})'.format(device_id))
            logger.log()
            artifact.move_to_cuda(device_id)
        elif device_id != artifact.cuda_device_id:
            logger.log_warning('Loading CUDA (device {0}) artifact to CUDA (device {1})'.format(artifact.cuda_device_id, device_id))
            logger.log()
            artifact.move_to_cuda(device_id)
        # Otherwise the artifact is already on the requested CUDA device.
    elif artifact.on_cuda:
        logger.log_warning('Loading CUDA artifact to CPU')
        logger.log()
        artifact.move_to_cpu()

    return artifact
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号