train_cnn.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:painters 作者: inejc 项目源码 文件源码
def _train_model():
    data_info = load_organized_data_info(IMGS_DIM_3D[1])
    dir_tr = data_info['dir_tr']
    dir_val = data_info['dir_val']

    gen_tr, gen_val = train_val_dirs_generators(BATCH_SIZE, dir_tr, dir_val)
    model = _cnn(IMGS_DIM_3D)

    model.fit_generator(
        generator=gen_tr,
        nb_epoch=MAX_EPOCHS,
        samples_per_epoch=data_info['num_tr'],
        validation_data=gen_val,
        nb_val_samples=data_info['num_val'],
        callbacks=[ModelCheckpoint(CNN_MODEL_FILE, save_best_only=True)],
        verbose=2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号