def _train_model():
    """Train the CNN on the organized image directories.

    Builds training and validation generators from the directories recorded
    in the organized-data info, constructs the network with ``_cnn``, and
    fits it for ``MAX_EPOCHS`` epochs. A ``ModelCheckpoint`` created with
    ``save_best_only=True`` writes the model to ``CNN_MODEL_FILE`` only when
    the monitored quantity improves (Keras' default monitor is ``val_loss``).

    Returns:
        None. The value returned by ``fit_generator`` is discarded, and the
        in-memory model is not returned; the trained model persists only
        through the checkpoint file.
    """
    # NOTE(review): IMGS_DIM_3D[1] is presumably the image side length used to
    # select the organized dataset variant -- confirm against
    # load_organized_data_info.
    info = load_organized_data_info(IMGS_DIM_3D[1])
    train_dir, val_dir = info['dir_tr'], info['dir_val']

    train_gen, val_gen = train_val_dirs_generators(
        BATCH_SIZE, train_dir, val_dir)

    net = _cnn(IMGS_DIM_3D)

    # Keep only the best-scoring model on disk rather than one per epoch.
    best_only_checkpoint = ModelCheckpoint(CNN_MODEL_FILE, save_best_only=True)

    # NOTE: nb_epoch / samples_per_epoch / nb_val_samples are the Keras 1.x
    # fit_generator keywords. Keras 2 renamed them (epochs, steps_per_epoch,
    # validation_steps) and counts batches rather than samples, so this call
    # must be migrated if the project upgrades Keras.
    net.fit_generator(
        generator=train_gen,
        nb_epoch=MAX_EPOCHS,
        samples_per_epoch=info['num_tr'],
        validation_data=val_gen,
        nb_val_samples=info['num_val'],
        callbacks=[best_only_checkpoint],
        verbose=2,
    )
评论列表
文章目录