train_utils.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:kaggle-carvana 作者: ematvey 项目源码 文件源码
def gpu_preloader_iter(dataloader):
    loader_iter = iter(dataloader)
    bx, by = None, None
    while 1:
        try:
            x, y = bx, by
            bx, by = next(loader_iter)
            if torch.is_tensor(bx):
                bx = bx.cuda(async=True)
            if torch.is_tensor(by):
                by = by.cuda(async=True)
            if x is None or y is None:
                x, y = next(loader_iter)
                if torch.is_tensor(x):
                    x = x.cuda()
                if torch.is_tensor(y):
                    y = y.cuda()
            yield x, y
        except StopIteration:
            if bx is not None:
                yield bx, by
            return
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号