def stream(stream):
    """Generator body that makes ``stream`` the active stream while suspended.

    If ``stream`` is None, this yields exactly once and changes nothing.
    Otherwise it records the stream returned by ``current_stream()`` and
    installs ``stream`` through ``torch._C._cuda_setStream(stream._cdata)``.
    It then yields once, and reinstalls the recorded stream afterwards, even
    if the code run at the yield point raised.

    NOTE(review): no decorator is visible here. This is presumably wrapped
    with ``contextlib.contextmanager`` so callers can write
    ``with stream(s): ...``; confirm against the surrounding file.

    Args:
        stream: the stream object to select, or None for a no-op. Only its
            ``_cdata`` attribute is read here.
    """
    if stream is not None:
        # Snapshot the active stream before switching so it can be restored.
        saved = current_stream()
        torch._C._cuda_setStream(stream._cdata)
        try:
            yield
        finally:
            # Always reinstate the snapshot, including on exceptions thrown
            # into the generator at the yield point.
            torch._C._cuda_setStream(saved._cdata)
    else:
        yield