profile_train.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:ParlAI 作者: facebookresearch 项目源码 文件源码
def main(parser):
    opt = parser.parse_args()

    if opt['torch']:
        with torch.autograd.profiler.profile() as prof:
            TrainLoop(parser).train()
        print(prof.total_average())

        sort_cpu = sorted(prof.key_averages(), key=lambda k: k.cpu_time)
        sort_cuda = sorted(prof.key_averages(), key=lambda k: k.cuda_time)

        def cpu():
            for e in sort_cpu:
                print(e)

        def cuda():
            for e in sort_cuda:
                print(e)

        cpu()

        if opt['debug']:
            print('`cpu()` prints out cpu-sorted list, '
                  '`cuda()` prints cuda-sorted list')

            pdb.set_trace()
    else:
        pr = cProfile.Profile()
        pr.enable()
        TrainLoop(parser).train()
        pr.disable()
        s = io.StringIO()
        sortby = 'cumulative'
        ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
        ps.print_stats()
        print(s.getvalue())
        if opt['debug']:
            pdb.set_trace()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号