def test_pickle_save_load_cuda_intercompatibility(
        self, net_cls, module_cls, tmpdir):
    """A net pickled with ``use_cuda=True`` must unpickle without CUDA.

    The net is initialized with ``use_cuda=True`` and pickled to a
    temporary file. It is then read back while
    ``torch.cuda.is_available`` is patched to report ``False``.
    Unpickling is expected to emit exactly one ``DeviceWarning`` with a
    fixed message. The restored net must end up with ``use_cuda`` set
    to ``False``.
    """
    from skorch.exceptions import DeviceWarning

    expected_message = (
        'Model configured to use CUDA but no CUDA '
        'devices available. Loading on CPU instead.')

    cuda_net = net_cls(module=module_cls, use_cuda=True).initialize()
    pickle_file = tmpdir.mkdir('skorch').join('testmodel.pkl')
    with open(str(pickle_file), 'wb') as out_stream:
        pickle.dump(cuda_net, out_stream)
    # Release the local reference; everything below works on the
    # unpickled copy only.
    del cuda_net

    # Pretend CUDA is unavailable while unpickling, recording warnings.
    # Context managers are entered left to right, as the nested form was.
    with patch('torch.cuda.is_available', lambda *_: False), \
            pytest.warns(DeviceWarning) as recorder, \
            open(str(pickle_file), 'rb') as in_stream:
        restored = pickle.load(in_stream)

    # CUDA was reported missing at load time, so the restored net
    # must no longer be configured to use it.
    assert restored.use_cuda is False

    caught = recorder.list
    assert len(caught) == 1  # exactly one warning expected
    assert caught[0].message.args[0] == expected_message
# [web page residue] 评论列表 ("Comment list")
# [web page residue] 文章目录 ("Table of contents")