def _strip_data_parallel_prefix(state_dict, prefix='module.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    ``torch.nn.DataParallel`` keeps the wrapped network in its ``module``
    attribute, so every key of the wrapper's ``state_dict()`` starts with
    ``'module.'``. Keys that do not start with *prefix* are copied unchanged,
    so this is also safe for checkpoints saved from an unwrapped model.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def _export_cpu_weights(model, checkpoint_path='./model_best.pth.tar',
                        output_path='model_cpu.pth.tar'):
    """Reload the best checkpoint into *model* and save plain CPU weights.

    The checkpoint's ``'state_dict'`` entry is loaded into the bare network
    (unwrapping ``DataParallel`` if needed). The network is then moved to the
    CPU, and its prefix-free ``state_dict()`` is written to *output_path*.
    """
    # The stripped keys match the inner network's parameter names, not the
    # DataParallel wrapper's ('module.'-prefixed) ones, so load into the inner
    # network.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    # map_location='cpu' lets a checkpoint holding CUDA tensors be read on a
    # host without a GPU; the weights are being exported for CPU anyway.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(_strip_data_parallel_prefix(checkpoint['state_dict']))
    net.cpu()
    torch.save(net.state_dict(), output_path)


def train(args):
    """Train a YOLO model on a VGG16 feature backbone, then export CPU weights.

    Args:
        args: namespace-like object. ``batch_size``, ``use_cuda``, ``lr`` and
            ``epochs`` are read here, and the whole object is forwarded to
            ``train_epoch``.

    Side effects:
        Calls ``save_checkpoint`` once per epoch with an ``is_best`` flag.
        After the last epoch, reloads ``./model_best.pth.tar``. NOTE(review):
        this is presumably the file ``save_checkpoint`` writes when
        ``is_best`` is true — confirm. Its weights are then saved, without
        the DataParallel ``module.`` prefix, to ``model_cpu.pth.tar``.
    """
    print('Dataset of instance(s) and batch size is {}'.format(args.batch_size))
    vgg = models.vgg16(pretrained=True)
    model = YOLO(vgg.features)
    if args.use_cuda:
        model = torch.nn.DataParallel(model)
        model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    best = float('inf')
    for epoch in range(1, args.epochs + 1):
        loss = train_epoch(epoch, model, optimizer, args)
        # The detections returned for this fixed sample image are not used
        # here. NOTE(review): the call is kept for whatever side effects
        # test_epoch has — confirm it is still needed.
        test_epoch(model, jpg='../data/1.jpg')
        # Only a strict improvement updates `best`. Unlike min(l, best),
        # this cannot let a NaN loss overwrite the running best.
        is_best = loss < best
        if is_best:
            best = loss
        save_checkpoint({
            'epoch': epoch + 1,  # presumably the epoch to resume from
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best)

    _export_cpu_weights(model)
# 评论列表 (scraped web-page residue: "Comment list")
# 文章目录 (scraped web-page residue: "Table of contents")