def unet_fit(name, start_t, end_t, start_v, end_v, check_name = None):
t = time.time()
callbacks = [EarlyStopping(monitor='val_loss', patience = 15,
verbose = 1),
ModelCheckpoint('/home/w/DS_Projects/Kaggle/DS Bowl 2017/Scripts/LUNA/CNN/Checkpoints/{}.h5'.format(name),
monitor='val_loss',
verbose = 0, save_best_only = True)]
if check_name is not None:
check_model = '/home/w/DS_Projects/Kaggle/DS Bowl 2017/Scripts/LUNA/CNN/Checkpoints/{}.h5'.format(check_name)
model = load_model(check_model,
custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
else:
model = unet_model()
model.fit_generator(generate_train(start_t, end_t), nb_epoch = 150, verbose = 1,
validation_data = generate_val(start_v, end_v),
callbacks = callbacks,
samples_per_epoch = 551, nb_val_samples = 50)
return
# In[5]:
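# Web-page navigation residue from the scraped source (not code):
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")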