def __enter__(self):
    """Make device ``self.idx`` the current CUDA device while the context is active.

    An ``idx`` of -1 makes the context a no-op: nothing is initialized and
    ``prev_idx`` is not recorded.  Otherwise CUDA is lazily initialized, the
    currently selected device index is saved in ``self.prev_idx``
    (presumably restored by ``__exit__``, which is not shown here), and the
    device is switched only if it differs from ``self.idx``.

    Returns:
        None, so ``with ... as x`` binds ``x`` to ``None``.
    """
    # Compare by value, not identity: ``is -1`` relies on CPython's small-int
    # cache, emits a SyntaxWarning on 3.8+, and misses int-like values that
    # equal -1 (e.g. int subclasses).
    if self.idx == -1:
        return
    _lazy_init()
    self.prev_idx = torch._C._cuda_getDevice()
    # Skip the set call when the requested device is already current.
    if self.prev_idx != self.idx:
        torch._C._cuda_setDevice(self.idx)