def train(model_name, fold_count, train_full_set=False, load_weights_path=None, ndsb3_holdout=0, manual_labels=True):
    batch_size = 16
    train_files, holdout_files = get_train_holdout_files(train_percentage=80, ndsb3_holdout=ndsb3_holdout, manual_labels=manual_labels, full_luna_set=train_full_set, fold_count=fold_count)
    # train_files = train_files[:100]
    # holdout_files = train_files[:10]
    train_gen = data_generator(batch_size, train_files, True)
    print("train files:", len(train_files), " holdout files:", len(holdout_files))
    holdout_gen = data_generator(batch_size, holdout_files, False)
    # Sanity check: pull a few batches from the holdout generator and undo the
    # normalization so the cubes could be saved and inspected visually.
    for i in range(0, 10):
        tmp = next(holdout_gen)
        cube_img = tmp[0][0].reshape(CUBE_SIZE, CUBE_SIZE, CUBE_SIZE, 1)
        cube_img = cube_img[:, :, :, 0]
        cube_img *= 255.
        cube_img += MEAN_PIXEL_VALUE
        # helpers.save_cube_img("c:/tmp/img_" + str(i) + ".png", cube_img, 4, 8)
        # print(tmp)

    learnrate_scheduler = LearningRateScheduler(step_decay)
    model = get_net(load_weight_path=load_weights_path)
    holdout_txt = "_h" + str(ndsb3_holdout) if manual_labels else ""
    if train_full_set:
        holdout_txt = "_fs" + holdout_txt
    # Checkpoint per epoch (best-only unless training on the full set) plus a fixed-name "best" checkpoint.
    checkpoint = ModelCheckpoint("workdir/model_" + model_name + "_" + holdout_txt + "_e" + "{epoch:02d}-{val_loss:.4f}.hd5", monitor='val_loss', verbose=1, save_best_only=not train_full_set, save_weights_only=False, mode='auto', period=1)
    checkpoint_fixed_name = ModelCheckpoint("workdir/model_" + model_name + "_" + holdout_txt + "_best.hd5", monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', period=1)
    # Keras 1 style fit_generator: (generator, samples_per_epoch, nb_epoch, ...).
    model.fit_generator(train_gen, len(train_files) / 1, 12, validation_data=holdout_gen, nb_val_samples=len(holdout_files) / 1, callbacks=[checkpoint, checkpoint_fixed_name, learnrate_scheduler])
    model.save("workdir/model_" + model_name + "_" + holdout_txt + "_end.hd5")
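
# A minimal usage sketch (an assumption, not taken from the original file): how
# train() might be invoked, first on the full LUNA set and then on one holdout
# fold. The model names and fold/holdout values below are illustrative
# placeholders, not the project's actual configuration.
if __name__ == "__main__":
    # hypothetical full-set run: no held-out NDSB3 fold, automatic labels
    train("luna16_fs", fold_count=-1, train_full_set=True, load_weights_path=None, ndsb3_holdout=0, manual_labels=False)
    # hypothetical cross-validation run: hold out NDSB3 fold 0, manual labels
    train("luna16_h0", fold_count=2, train_full_set=False, load_weights_path=None, ndsb3_holdout=0, manual_labels=True)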
# Source file: step5_train_nodule_detector.py