def use_cudnn(should_use):
    """Temporarily set ``torch.backends.cudnn.enabled`` to *should_use*.

    Generator that records the current value of
    ``torch.backends.cudnn.enabled``, overwrites it with *should_use*,
    yields exactly once (without a value), and then writes the recorded
    value back.

    NOTE(review): presumably wrapped with ``contextlib.contextmanager``
    somewhere outside this view so it can be used in a ``with`` statement;
    confirm the decorator is present at the definition site.

    Args:
        should_use: Value assigned to ``torch.backends.cudnn.enabled``
            while the generator is suspended at its ``yield``.

    Yields:
        None, once, while the override is in effect.
    """
    previous_setting = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = should_use
    try:
        yield
    finally:
        # Runs on normal resumption and when an exception is thrown into
        # the generator, so the prior flag value is always written back.
        torch.backends.cudnn.enabled = previous_setting