cifar10_custom_dataset_gap.py source code

python

Project: pytorch_60min_blitz  Author: kyuhyoung
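import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Net_gap and make_dataloader_custom_file are defined elsewhere in the project (not shown here).


def initialize(is_gpu, dir_data, di_set_transform, ext_img, n_img_per_batch, n_worker):
    """Build the data loaders, the GAP network, the loss, the optimizer and the LR scheduler."""
    trainloader, testloader, li_class = make_dataloader_custom_file(
        dir_data, di_set_transform, ext_img, n_img_per_batch, n_worker)

    net = Net_gap()
    criterion = nn.CrossEntropyLoss()
    if is_gpu:
        # Module.cuda() moves the parameters in place; do it before creating the optimizer.
        net.cuda()
        criterion.cuda()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # Cut the LR by 10x when the monitored loss has not dropped by more than 1e-5 for 8 epochs.
    # torch.optim.lr_scheduler.ReduceLROnPlateau has no Keras-style `epsilon` argument; the
    # equivalent is threshold + threshold_mode='abs'. `verbose` is deprecated in recent releases.
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=8, threshold=1e-5,
                                  threshold_mode='abs', min_lr=1e-6)

    return trainloader, testloader, net, criterion, optimizer, scheduler, li_class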
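Unlike step-based schedulers, ReduceLROnPlateau must be fed the monitored value (typically the epoch's validation loss) through `scheduler.step(loss)`. The sketch below is a self-contained illustration, not part of the original file: it uses a stand-in `nn.Linear` model and a hypothetical helper `lr_history` to show how the settings used above (`patience=8`, absolute threshold `1e-5`, `min_lr=1e-6`) change the learning rate when the loss stops improving.

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


def lr_history(losses, patience=8, threshold=1e-5, min_lr=1e-6):
    """Hypothetical helper: return the LR after each scheduler.step(loss) call."""
    model = nn.Linear(4, 2)  # stand-in model; only its parameters matter here
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=patience,
                                  threshold=threshold, threshold_mode='abs',
                                  min_lr=min_lr)
    history = []
    for loss in losses:
        # In real training this would be the epoch's validation loss.
        scheduler.step(loss)
        history.append(optimizer.param_groups[0]['lr'])
    return history


if __name__ == "__main__":
    # A loss that never improves: after more than `patience` consecutive
    # non-improving epochs, the LR is multiplied by the default factor 0.1.
    print(lr_history([1.0] * 12))
```

With a constant loss the first call sets the best value, so the LR stays at 0.001 for nine calls and drops to about 0.0001 on the tenth; repeated plateaus keep dividing it by 10 until it reaches `min_lr`.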
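The definition of `Net_gap` is not part of this snippet. Assuming `gap` stands for global average pooling, a network of that kind might look like the hypothetical `GapNetSketch` below (layer sizes are illustrative assumptions, not the author's). A 1x1 convolution produces one feature map per class, and averaging each map over its spatial extent yields the logits that `nn.CrossEntropyLoss` expects, with no fully connected layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GapNetSketch(nn.Module):
    """Hypothetical GAP classifier for 3x32x32 CIFAR-10 images (not the original Net_gap)."""

    def __init__(self, n_class=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 1x1 conv maps the 64 feature channels to one channel per class.
        self.conv_cls = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 32x32 -> 16x16
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 16x16 -> 8x8
        x = self.conv_cls(x)                        # (N, n_class, 8, 8)
        # Global average pooling: average each class map over all spatial positions.
        x = F.adaptive_avg_pool2d(x, 1)             # (N, n_class, 1, 1)
        return torch.flatten(x, 1)                  # (N, n_class) logits


if __name__ == "__main__":
    net = GapNetSketch()
    logits = net(torch.randn(4, 3, 32, 32))
    print(logits.shape)  # torch.Size([4, 10])
```

Because the pooling is adaptive, the same network also accepts larger images (for example 64x64) and still returns one logit per class.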