cifar10_custom_dataset.py source code

python

Project: pytorch_60min_blitz    Author: kyuhyoung
from time import time

import numpy as np
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# initialize() and train() are defined elsewhere in this file (not shown in this excerpt).


def main():

    #'''
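    # Data-loading implementations to benchmark; the quoted block further below is an alternative ordering.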
    li_mode = ['TORCHVISION_MEMORY', 'TORCHVISION_IMAGEFOLDER',
               'CUSTOM_MEMORY', 'CUSTOM_FILE', 'CUSTOM_TENSORDATSET']
    #'''
    '''
    li_mode = ['CUSTOM_TENSORDATSET', 'TORCHVISION_MEMORY',
               'TORCHVISION_IMAGEFOLDER', 'CUSTOM_MEMORY', 'CUSTOM_FILE']
    '''
    dir_data = './data'
    ext_img = 'png'
    #n_epoch = 100
    n_epoch = 50
    #n_img_per_batch = 40
    n_img_per_batch = 60
    n_worker = 4
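    # Log the running training loss roughly every 20,000 images, rounded to a multiple of the batch size.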
    interval_train_loss = int(round(20000 / n_img_per_batch)) * n_img_per_batch
    is_gpu = torch.cuda.device_count() > 0

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
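    # ToTensor() scales pixel values to [0, 1]; Normalize with mean 0.5 and std 0.5 per channel maps them to [-1, 1].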

    di_set_transform = {'train' : transform, 'test' : transform}

    #fig = plt.figure(num=None, figsize=(1, 2), dpi=500)
    fig = plt.figure(num=None, figsize=(12, 18), dpi=100)
    plt.ion()
    ax_time = fig.add_subplot(3, 1, 1)
    ax_time.set_title(
        'Elapsed time (sec.) of validation on 10k images vs. epoch. Note that value for epoch 0 is the elapsed time of init.')
    ax_time.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax_loss_train = fig.add_subplot(3, 1, 2)
    ax_loss_train.set_title('Avg. train loss per image vs. # train input images')
    ax_loss_train.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax_loss_val = fig.add_subplot(3, 1, 3)
    ax_loss_val.set_title('Avg. val. loss per image vs. # train input images')
    ax_loss_val.xaxis.set_major_locator(MaxNLocator(integer=True))
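    # For each data-loading mode: build its loaders, net, loss, and optimizer, then train while plotting timing and loss curves.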
    for i_m, mode in enumerate(li_mode):
        start = time()
        trainloader, testloader, net, criterion, optimizer, scheduler, li_class = \
            initialize(
                mode, is_gpu, dir_data, di_set_transform, ext_img, n_img_per_batch, n_worker)
        lap_init = time() - start
        #print('[%s] lap of initializing : %d sec' % (mode, lap_init))
        kolor = np.random.rand(3)  # random RGB color for this mode's curves
        #if 2 == i_m:
        #    a = 0
        train(is_gpu, trainloader, testloader, net, criterion, optimizer, scheduler, #li_class,
              n_epoch, lap_init, ax_time, ax_loss_train, ax_loss_val,
              mode, kolor, n_img_per_batch, interval_train_loss)
    print('Finished all.')
    plt.pause(1000)
    return
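The initialize() and train() helpers are not part of this excerpt. As a rough sketch of what two of the benchmarked modes might build (the function and class names below are illustrative, not taken from the original file): 'TORCHVISION_MEMORY' presumably uses the stock torchvision CIFAR-10 dataset decoded into memory, while 'CUSTOM_FILE' would be a hand-written Dataset that keeps only file paths and decodes each image lazily.

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets

# 'TORCHVISION_MEMORY': the stock CIFAR-10 dataset, downloaded and decoded by torchvision.
def make_torchvision_memory_loader(dir_data, transform, n_img_per_batch, n_worker):
    trainset = datasets.CIFAR10(root=dir_data, train=True,
                                download=True, transform=transform)
    return DataLoader(trainset, batch_size=n_img_per_batch,
                      shuffle=True, num_workers=n_worker)

# 'CUSTOM_FILE': a hand-rolled Dataset that stores only file paths and reads
# each image from disk in __getitem__.
class Cifar10FromFiles(Dataset):  # illustrative name, not from the original file
    def __init__(self, dir_img, ext_img, transform=None):
        self.paths = sorted(
            os.path.join(dir_img, f) for f in os.listdir(dir_img)
            if f.endswith('.' + ext_img))
        # Assumption: the class name is encoded in the file name, e.g. 'cat_00042.png'.
        self.labels = [os.path.basename(p).split('_')[0] for p in self.paths]
        self.classes = sorted(set(self.labels))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.classes.index(self.labels[idx])

Either dataset would then be wrapped in a DataLoader with batch_size=n_img_per_batch and num_workers=n_worker, matching the arguments passed to initialize() above.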