trainer.py source code

python

Project: deepanalytics_compe26_benchmark    Author: takagiwa-ss
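def train():
    # Written against the Keras 1.x API. Relies on names defined elsewhere in
    # the project (not shown in this excerpt): resnet(), load_train_data(),
    # _objective, _sgd_lr, _sgd_decay, _batch_size and _nb_epoch.
    import logging

    from keras.optimizers import SGD
    from keras.preprocessing.image import ImageDataGenerator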

    logging.info('... building model')

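    # SGD with Nesterov momentum; Keras 1.x decays the learning rate per
    # update as lr / (1 + decay * iterations).
    sgd = SGD(lr=_sgd_lr, decay=_sgd_decay, momentum=0.9, nesterov=True)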

    model = resnet()
    model.compile(
        loss=_objective,
        optimizer=sgd,
        metrics=['mae'])

    logging.info('... loading data')

    X, Y = load_train_data()

    logging.info('... training')

    datagen = ImageDataGenerator(
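        # data augmentation: random horizontal and vertical shifts of up to
        # 1/8 of the image width/height; rotation, shear and zoom are disabled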
        width_shift_range  = 1./8.,
        height_shift_range = 1./8.,
        rotation_range     = 0.,
        shear_range        = 0.,
        zoom_range         = 0.,
    )

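    # Keras 1.x arguments: samples_per_epoch is the number of samples drawn
    # per epoch (here one full pass over X) and nb_epoch the number of epochs.
    # Keras 2 renamed these to steps_per_epoch (counted in batches) and epochs.
    model.fit_generator(
        datagen.flow(X, Y, batch_size=_batch_size),
        samples_per_epoch=X.shape[0],
        nb_epoch=_nb_epoch,
        verbose=1)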

    return model
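The optimizer in `train()` is SGD with Nesterov momentum and time-based learning-rate decay. The pure-Python sketch below illustrates what one update does. It assumes the Keras 1.x `SGD` update rule:

* an effective rate of `lr / (1 + decay * iterations)`;
* a velocity `v = momentum * v - lr * g`;
* the Nesterov step `p + momentum * v - lr * g`.

The helper name `sgd_nesterov_step` and the toy quadratic loss are hypothetical and serve only as illustration.

```python
def sgd_nesterov_step(params, grads, velocities, lr, decay, momentum, iterations):
    """One SGD update with Nesterov momentum and time-based decay.

    Hypothetical helper that mirrors the (assumed) Keras 1.x SGD update rule
    on plain Python floats; returns (new_params, new_velocities).
    """
    # Time-based decay: the effective learning rate shrinks with the update count.
    lr_t = lr / (1.0 + decay * iterations)
    new_params, new_velocities = [], []
    for p, g, v in zip(params, grads, velocities):
        v = momentum * v - lr_t * g                      # velocity update
        new_params.append(p + momentum * v - lr_t * g)   # Nesterov look-ahead step
        new_velocities.append(v)
    return new_params, new_velocities


# Usage: minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
params, velocities = [0.0], [0.0]
for it in range(200):
    grads = [2.0 * (w - 3.0) for w in params]
    params, velocities = sgd_nesterov_step(
        params, grads, velocities, lr=0.05, decay=1e-3, momentum=0.9, iterations=it)
print(round(params[0], 4))
```

Because `iterations` counts parameter updates (batches), not epochs, a given `_sgd_decay` shrinks the rate faster when the batch size is small.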
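With `width_shift_range = 1./8.` and `height_shift_range = 1./8.`, each training image is translated by a random offset of up to one eighth of its width and height. The sketch below approximates this on a 2-D list of pixels. It is a simplification, not the Keras implementation, in two ways:

* offsets are rounded to whole pixels;
* the uncovered border is filled by replicating edge pixels, which is roughly what `ImageDataGenerator`'s default `fill_mode='nearest'` does.

The names `shift_image` and `random_shift` are hypothetical.

```python
import random


def shift_image(img, dx, dy):
    """Translate a 2-D image (list of rows) by dx columns and dy rows.

    Pixels that would come from outside the image are filled with the
    nearest edge pixel (edge replication).
    """
    h, w = len(img), len(img[0])
    out = []
    for r in range(h):
        src_r = min(max(r - dy, 0), h - 1)        # clamp source row into range
        row = []
        for c in range(w):
            src_c = min(max(c - dx, 0), w - 1)    # clamp source column into range
            row.append(img[src_r][src_c])
        out.append(row)
    return out


def random_shift(img, width_range=1.0 / 8.0, height_range=1.0 / 8.0, rng=random):
    """Shift by a random whole-pixel offset of at most range * size per axis."""
    h, w = len(img), len(img[0])
    dx = round(rng.uniform(-width_range, width_range) * w)
    dy = round(rng.uniform(-height_range, height_range) * h)
    return shift_image(img, dx, dy)


# Usage: with a range of 1/8, an 8x8 image moves by at most 1 pixel per axis.
img = [[r * 8 + c for c in range(8)] for r in range(8)]
shifted = shift_image(img, dx=1, dy=0)
print(shifted[0])   # [0, 0, 1, 2, 3, 4, 5, 6]
```

Since rotation, shear and zoom are set to 0 in `train()`, translation is the only augmentation applied to each batch.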
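In the Keras 1.x call, `samples_per_epoch=X.shape[0]` makes each epoch one full pass over the training set. Keras 2 made two changes here:

* it replaced this argument with `steps_per_epoch`, which counts batches rather than samples;
* it renamed `nb_epoch` to `epochs`.

The small helper below converts between the two. It is hypothetical and for illustration only, and it assumes the final partial batch of each pass counts as one step.

```python
def steps_per_epoch(num_samples, batch_size):
    """Batches needed to see every sample once (the last batch may be partial)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Integer ceiling division, exact for arbitrarily large counts.
    return (num_samples + batch_size - 1) // batch_size


print(steps_per_epoch(1000, 32))   # 32: 31 full batches plus 1 partial batch of 8
```

With that value, a Keras 2 version of the call would pass `steps_per_epoch=steps_per_epoch(X.shape[0], _batch_size)` and `epochs=_nb_epoch`. In TensorFlow 2.1 and later, `fit_generator` is deprecated and `model.fit` accepts the generator directly.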