def cnn3d_genfit(name, nn_model, epochs, start_t, end_t, start_v, end_v,
                 nb_train, nb_val, check_name=None,
                 checkpoint_dir='/home/w/DS_Projects/Kaggle/DS Bowl 2017/Scripts/LUNA/CNN/Checkpoints',
                 patience=15):
    """Train a 3D CNN from data generators with early stopping and checkpointing.

    Training monitors ``val_loss``: it stops once ``val_loss`` has not improved
    for ``patience`` epochs. The model is saved to
    ``<checkpoint_dir>/<name>.h5`` each time ``val_loss`` improves.

    Args:
        name: Base file name (without ``.h5``) of the checkpoint to write.
        nn_model: Model to train. Ignored when ``check_name`` is given.
        epochs: Maximum number of training epochs.
        start_t, end_t: Forwarded as ``generate_train(start_t, end_t)``
            (presumably bounds of the training slice - TODO confirm).
        start_v, end_v: Forwarded as ``generate_val(start_v, end_v)``.
        nb_train: Passed to ``fit_generator`` as ``samples_per_epoch``.
        nb_val: Passed to ``fit_generator`` as ``nb_val_samples``.
        check_name: If not None, load ``<checkpoint_dir>/<check_name>.h5`` with
            ``load_model`` and continue training it instead of ``nn_model``.
        checkpoint_dir: Directory used for both loading and saving checkpoints.
            Defaults to the previously hard-coded location.
        patience: Epochs without ``val_loss`` improvement before stopping.

    Returns:
        The value returned by ``model.fit_generator`` (earlier versions of this
        function discarded it and returned None).
    """
    import os

    def _checkpoint_path(model_name):
        # Single source of truth for '<checkpoint_dir>/<model_name>.h5' so the
        # save location and the resume location cannot drift apart.
        return os.path.join(checkpoint_dir, '{}.h5'.format(model_name))

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=patience, verbose=1),
        # save_best_only=True: the file is only rewritten when val_loss improves.
        ModelCheckpoint(_checkpoint_path(name), monitor='val_loss',
                        verbose=0, save_best_only=True),
    ]

    if check_name is not None:
        # Resume from an earlier checkpoint; nn_model is not used in this case.
        model = load_model(_checkpoint_path(check_name))
    else:
        model = nn_model

    # NOTE(review): nb_epoch / samples_per_epoch / nb_val_samples are Keras 1.x
    # argument names; Keras 2 renamed them (epochs, steps_per_epoch,
    # validation_steps) and changed the units from samples to steps.
    # Confirm the installed Keras version before upgrading.
    history = model.fit_generator(
        generate_train(start_t, end_t),
        nb_epoch=epochs,
        verbose=1,
        validation_data=generate_val(start_v, end_v),
        callbacks=callbacks,
        samples_per_epoch=nb_train,
        nb_val_samples=nb_val,
    )
    return history
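# (Scraped web-page residue, translated: "Comment list")
# (Scraped web-page residue, translated: "Table of contents")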