train_resnet.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:Synkhronos 作者: astooke 项目源码 文件源码
def train_resnet(
        batch_size=64,  # batch size on each GPU
        validFreq=1,
        do_valid=False,
        learning_rate=1e-3,
        update_rule=updates.sgd,  # updates.nesterov_momentum,
        n_epoch=3,
        n_gpu=None,  # later get this from synk.fork
        **update_kwargs):
    """Train a ResNet on synthetic data across GPUs using Synkhronos.

    Args:
        batch_size: Minibatch size per GPU. Each training step processes
            ``batch_size * n_gpu`` examples in total.
        validFreq: Run validation on epochs where ``ep % validFreq == 0``
            (only when ``do_valid`` is true).
        do_valid: Whether to compute validation loss and accuracy.
        learning_rate: Base learning rate. It is multiplied by the number
            of GPUs before the training function is built.
        update_rule: Update function passed to ``build_training``.
        n_epoch: Number of passes over the training data.
        n_gpu: Number of GPUs to fork onto. ``None`` lets ``synk.fork``
            use all available GPUs.
        **update_kwargs: Extra keyword arguments forwarded to
            ``build_training`` (presumably for ``update_rule``; confirm).

    Prints setup time, then per-epoch training loss (optionally with
    validation loss/accuracy) and epoch time, then total training time.
    """
    n_gpu = synk.fork(n_gpu)  # (n_gpu==None will use all)

    t_0 = time.time()
    print("Loading data (synthetic)")
    train, valid, test = load_data()

    x_train, y_train = [synk.data(d) for d in train]
    x_valid, y_valid = [synk.data(d) for d in valid]
    # NOTE(review): the test split is wrapped with synk.data but never used
    # below; kept so any side effects of synk.data are preserved.
    x_test, y_test = [synk.data(d) for d in test]

    full_mb_size = batch_size * n_gpu
    learning_rate = learning_rate * n_gpu  # (one technique for larger minibatches)
    # NOTE(review): this is 0 when the validation set is smaller than one full
    # minibatch; confirm how f_predict handles num_slices=0.
    num_valid_slices = len(x_valid) // n_gpu // batch_size
    print("Will compute validation using {} slices".format(num_valid_slices))

    print("Building model")
    resnet = build_resnet()
    params = L.get_all_params(resnet.values(), trainable=True)

    f_train_minibatch, f_predict = build_training(resnet, params, update_rule,
                                                  learning_rate=learning_rate,
                                                  **update_kwargs)

    synk.distribute()
    synk.broadcast(params)  # (ensure all GPUs have same values)

    t_last = t_1 = time.time()
    print("Total setup time: {:,.1f} s".format(t_1 - t_0))
    print("Starting training")

    for ep in range(n_epoch):
        loss_sum = 0.
        n_batches = 0
        for mb_idxs in iter_mb_idxs(full_mb_size, len(x_train), shuffle=True):
            loss_sum += f_train_minibatch(x_train, y_train, batch=mb_idxs)
            n_batches += 1
        # Average over the minibatches actually run. If iter_mb_idxs yielded
        # nothing this epoch, report nan instead of dividing by zero.
        train_loss = loss_sum / n_batches if n_batches else float("nan")

        print("\nEpoch: ", ep)
        print("Training Loss: {:.3f}".format(train_loss))

        if do_valid and ep % validFreq == 0:
            valid_loss, valid_mc = f_predict(x_valid, y_valid,
                                             num_slices=num_valid_slices)
            # valid_mc is used as an error rate here: accuracy = 1 - valid_mc.
            print("Validation Loss: {:.3f},   Accuracy: {:.3f}".format(
                float(valid_loss), float(1 - valid_mc)))

        t_2 = time.time()
        print("(epoch total time: {:,.1f} s)".format(t_2 - t_last))
        t_last = t_2
    print("\nTotal training time: {:,.1f} s".format(t_last - t_1))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号