chainer_alex.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:mlimages 作者: icoxfog417 项目源码 文件源码
def train(epoch=10, batch_size=32, gpu=False):
    if gpu:
        cuda.check_cuda_available()
    xp = cuda.cupy if gpu else np

    td = TrainingData(LABEL_FILE, img_root=IMAGES_ROOT, image_property=IMAGE_PROP)

    # make mean image
    if not os.path.isfile(MEAN_IMAGE_FILE):
        print("make mean image...")
        td.make_mean_image(MEAN_IMAGE_FILE)
    else:
        td.mean_image_file = MEAN_IMAGE_FILE

    # train model
    label_def = LabelingMachine.read_label_def(LABEL_DEF_FILE)
    model = alex.Alex(len(label_def))
    optimizer = optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    optimizer.setup(model)
    epoch = epoch
    batch_size = batch_size

    print("Now our model is {0} classification task.".format(len(label_def)))
    print("begin training the model. epoch:{0} batch size:{1}.".format(epoch, batch_size))

    if gpu:
        model.to_gpu()

    for i in range(epoch):
        print("epoch {0}/{1}: (learning rate={2})".format(i + 1, epoch, optimizer.lr))
        td.shuffle(overwrite=True)

        for x_batch, y_batch in td.generate_batches(batch_size):
            x = chainer.Variable(xp.asarray(x_batch))
            t = chainer.Variable(xp.asarray(y_batch))

            optimizer.update(model, x, t)
            print("loss: {0}, accuracy: {1}".format(float(model.loss.data), float(model.accuracy.data)))

        serializers.save_npz(MODEL_FILE, model)
        optimizer.lr *= 0.97
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号