cnnT1.py source file

python

Project: future-price-predictor  Author: htfy96
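import datetime
import os

import torch
import torch.optim as optim
import torch.utils.data as data_utils
from torch.autograd import Variable  # deprecated no-op wrapper since PyTorch 0.4

# target_onehot_to_classnum_tensor, evaluate and cuda_ava are defined elsewhere in the project.


def train(model, db, args, bsz=32, eph=1, use_cuda=False):
    """Train `model` on dataset `db` for `eph` epochs with batch size `bsz`.

    Checkpoints are written to the saved_model/ directory. `args` is unused here.
    """
    print("Training...")

    trainloader = data_utils.DataLoader(dataset=db, batch_size=bsz, shuffle=True)
    # CrossEntropyLoss expects class indices as targets, not one-hot vectors.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    best_loss = float('inf')  # lowest single-batch loss seen so far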

    for epoch in range(eph):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 1):
            inputs, targets = data
            inputs = inputs.unsqueeze(1)
            targets = target_onehot_to_classnum_tensor(targets)
            if use_cuda and cuda_ava:
                inputs = Variable(inputs.float().cuda())
                targets = Variable(targets.cuda())
            else:
                inputs = Variable(inputs.float())
                targets = Variable(targets)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

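            # loss is a 0-dim tensor; indexing it with loss.data[0] raises an
            # error on current PyTorch, so read the Python float with .item().
            last_loss = loss.item()
            running_loss += last_loss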
            if i % 100 == 0:
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i, running_loss / 100))
                running_loss = 0

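            if last_loss < best_loss:
                best_loss = last_loss
                # Note: this is accuracy on the training data, not on a held-out set.
                acc = evaluate(model, trainloader, use_cuda)
                # Use '-' rather than ':' in the timestamp so the name is valid on Windows.
                stamp = datetime.datetime.now().strftime("%b_%d_%H-%M-%S")
                name = 'cnnT1_epoch_{}_iter_{}_loss_{}_acc_{}_{}.t7'.format(
                    epoch + 1, i, last_loss, acc, stamp)
                os.makedirs('saved_model', exist_ok=True)  # torch.save does not create it
                torch.save(model.state_dict(), os.path.join('saved_model', name))
    acc = evaluate(model, trainloader, use_cuda)
    os.makedirs('saved_model', exist_ok=True)
    torch.save(model.state_dict(), os.path.join('saved_model', 'cnnT1_all_acc_{}.t7'.format(acc)))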

    print("Finished Training!")
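
The checkpoint names written by train() pack the epoch, iteration, loss, accuracy and a timestamp into one string. The following is a minimal sketch of how that naming could be factored into a standalone helper. `checkpoint_name` is a hypothetical function, not part of the original project, and it assumes that `loss` and `acc` are plain Python floats. It uses only the standard library, a numeric date instead of the locale-dependent `%b`, and no `:` characters, so the resulting path should also be valid on Windows.

```python
import datetime
import os


def checkpoint_name(epoch, iteration, loss, acc, now=None, directory="saved_model"):
    """Build a checkpoint path similar to the ones train() writes (hypothetical helper)."""
    now = now or datetime.datetime.now()
    # '-' instead of ':' keeps the file name valid on Windows file systems.
    stamp = now.strftime("%Y%m%d_%H-%M-%S")
    # Fixed precision keeps the names short and easy to compare.
    name = "cnnT1_epoch_{}_iter_{}_loss_{:.4f}_acc_{:.4f}_{}.t7".format(
        epoch, iteration, loss, acc, stamp)
    return os.path.join(directory, name)


# Passing a fixed timestamp makes the result reproducible.
path = checkpoint_name(1, 100, 0.6931, 0.5,
                       now=datetime.datetime(2018, 1, 2, 3, 4, 5))
print(path)
```

Inside train(), the result could be passed straight to torch.save(model.state_dict(), path) once the target directory exists.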