train.py source code

python

Project: yolo-pytorch  Author: makora9143
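from collections import OrderedDict

import torch
import torch.optim as optim
from torchvision import models

# YOLO, train_epoch, test_epoch and save_checkpoint are defined elsewhere in the project.


def train(args):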
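    print('Batch size is {}'.format(args.batch_size))
    # Load VGG-16 (the positional True is the legacy `pretrained` flag)
    # and use its convolutional feature extractor as the YOLO backbone.
    vgg = models.vgg16(True)
    model = YOLO(vgg.features)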
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    best = float('inf')  # lowest training loss seen so far

    for epoch in range(1, args.epochs+1):
        loss = train_epoch(epoch, model, optimizer, args)

        # Run detection on a sample image after each epoch (the results are not used further here).
        upperleft, bottomright, classes, confs = test_epoch(model, jpg='../data/1.jpg')
        is_best = loss < best
        best = min(loss, best)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best)
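    # Reload the best checkpoint (save_checkpoint is expected to write it to this path).
    checkpoint = torch.load('./model_best.pth.tar')
    state_dict = checkpoint['state_dict']

    # DataParallel prefixes every parameter name with 'module.'; strip it
    # only where present so the weights load into a plain model.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v

    # Unwrap DataParallel (if used) before loading the unprefixed weights.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    model.load_state_dict(new_state_dict)
    model.cpu()

    # Save a CPU-only copy of the weights.
    torch.save(model.state_dict(), 'model_cpu.pth.tar')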
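
The key-renaming step at the end of train() can be tried in isolation. The following is a minimal sketch, not the project's own code: the helper name strip_module_prefix and the example key names are hypothetical, and plain integers stand in for tensors so that it runs without PyTorch.

```python
from collections import OrderedDict

PREFIX = 'module.'  # prefix that DataParallel adds to every parameter name


def strip_module_prefix(state_dict, prefix=PREFIX):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Strip only when the prefix is present, so an unwrapped model's
        # state dict passes through unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Placeholder values stand in for tensors; the key names are made up for illustration.
wrapped = OrderedDict([('module.features.0.weight', 1), ('module.fc.bias', 2)])
plain = strip_module_prefix(wrapped)
print(list(plain.keys()))
```

Checking for the prefix before slicing keeps the helper safe to call on checkpoints saved without DataParallel, where a fixed k[7:] slice would cut into the real parameter names.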