train_resnet_theano.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: Synkhronos  作者: astooke  项目源码  文件源码
def train_resnet(
        batch_size=64,  # batch size on each GPU
        validFreq=1,
        do_valid=False,
        learning_rate=1e-3,
        update_rule=updates.sgd,  # updates.nesterov_momentum,
        n_epoch=3,
        **update_kwargs):
    """Train a ResNet on synthetic data on a single GPU, printing progress.

    Loads the train/valid/test splits via ``load_data`` (the test split is not
    used here), builds the network with ``build_resnet``, compiles the
    training and prediction functions with ``build_training``, then runs
    ``n_epoch`` passes of shuffled minibatch training. Per-epoch training loss,
    optional validation loss/accuracy, and wall-clock timings are printed.

    Args:
        batch_size: Minibatch size passed to ``iter_mb_idxs`` (per-GPU batch
            size, per the original comment).
        validFreq: Validate on epochs where ``ep % validFreq == 0``. Only
            consulted when ``do_valid`` is true, and must then be non-zero.
        do_valid: Whether to evaluate on the validation split.
        learning_rate: Forwarded to ``build_training``.
        update_rule: Update rule forwarded to ``build_training``.
        n_epoch: Number of passes over the training data.
        **update_kwargs: Extra keyword arguments forwarded to ``build_training``.

    Raises:
        ValueError: If ``do_valid`` is true and ``validFreq`` is zero. This is
            checked before any setup so it does not fail after a full epoch.
    """
    # Fail fast: ``ep % 0`` would otherwise raise only after the first
    # (potentially long) training epoch had already completed.
    if do_valid and validFreq == 0:
        raise ValueError("validFreq must be non-zero when do_valid is True")

    # Initialize single GPU.
    theano.gpuarray.use("cuda")

    t_0 = time.time()
    print("Loading data (synthetic)")
    train, valid, _ = load_data()  # test split is not used by this routine
    x_train, y_train = train
    x_valid, y_valid = valid

    print("Building model")
    resnet = build_resnet()
    params = L.get_all_params(resnet.values(), trainable=True)

    f_train_minibatch, f_predict = build_training(resnet, params, update_rule,
                                                  learning_rate=learning_rate,
                                                  **update_kwargs)

    t_last = t_1 = time.time()
    print("Total setup time: {:,.1f} s".format(t_1 - t_0))
    print("Starting training")

    for ep in range(n_epoch):
        train_loss = _train_one_epoch(f_train_minibatch, x_train, y_train,
                                      batch_size)

        print("\nEpoch: ", ep)
        print("Training Loss: {:.3f}".format(train_loss))

        if do_valid and ep % validFreq == 0:
            valid_loss, valid_mc = _evaluate(f_predict, x_valid, y_valid,
                                             batch_size)
            # ``.3f`` (precision 3) matches the training-loss line above; the
            # previous ``3f`` only set a minimum width of 3.
            print("Validation Loss: {:.3f},   Accuracy: {:.3f}".format(
                valid_loss, 1 - valid_mc))

        t_2 = time.time()
        print("(epoch total time: {:,.1f} s)".format(t_2 - t_last))
        t_last = t_2
    print("\nTotal training time: {:,.1f} s".format(t_last - t_1))


def _mean(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (no minibatches ran)."""
    return total / count if count else float("nan")


def _train_one_epoch(f_train_minibatch, x, y, batch_size):
    """Run one shuffled minibatch pass over ``(x, y)``; return the mean minibatch loss."""
    total_loss = 0.
    n_batches = 0
    for mb_idxs in iter_mb_idxs(batch_size, len(x), shuffle=True):
        total_loss += f_train_minibatch(x[mb_idxs], y[mb_idxs])
        n_batches += 1
    return _mean(total_loss, n_batches)


def _evaluate(f_predict, x, y, batch_size):
    """Run one unshuffled pass over ``(x, y)``; return the per-minibatch means of f_predict's two outputs.

    The second output is reported by the caller as ``1 - mc`` accuracy, so it
    is presumably a misclassification rate — confirm against ``build_training``.
    """
    total_loss = total_mc = 0.
    n_batches = 0
    for mb_idxs in iter_mb_idxs(batch_size, len(x), shuffle=False):
        mb_loss, mb_mc = f_predict(x[mb_idxs], y[mb_idxs])
        total_loss += mb_loss
        total_mc += mb_mc
        n_batches += 1
    return _mean(total_loss, n_batches), _mean(total_mc, n_batches)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号