3DUNet_train_generator.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:Kaggle-DSB 作者: Wrosinski 项目源码 文件源码
def unet_fit(name, start_t, end_t, start_v, end_v, check_name = None):
    """Train a U-Net from generators, checkpointing the best model to disk.

    A model is either built fresh via ``unet_model()`` or, when
    ``check_name`` is given, restored from a previously saved checkpoint.
    It is then fitted with ``fit_generator`` using early stopping and
    best-only checkpointing, both monitoring ``val_loss``.

    Args:
        name: Base file name (without ``.h5``) under which the best model
            of this run is saved in the checkpoint directory.
        start_t, end_t: Forwarded unchanged to ``generate_train``; their
            exact meaning is defined by that generator (presumably the
            bounds of the training range -- confirm against it).
        start_v, end_v: Forwarded unchanged to ``generate_val``; same
            caveat as above for the validation range.
        check_name: Optional base file name of an existing checkpoint to
            resume from. When ``None`` a new model is created.

    Returns:
        None. The results of training are the checkpoint file written by
        ``ModelCheckpoint`` and the console output.
    """
    # Single template for both the checkpoint being written and the one
    # optionally restored, so the two paths always point at the same folder.
    checkpoint_template = '/home/w/DS_Projects/Kaggle/DS Bowl 2017/Scripts/LUNA/CNN/Checkpoints/{}.h5'

    callbacks = [
        # Stop once val_loss has gone 15 epochs without improving.
        EarlyStopping(monitor='val_loss', patience=15, verbose=1),
        # Keep only the model with the best val_loss seen so far.
        ModelCheckpoint(checkpoint_template.format(name),
                        monitor='val_loss',
                        verbose=0, save_best_only=True),
    ]

    if check_name is not None:
        # The saved model references the custom Dice loss/metric by name, so
        # they have to be supplied again for deserialization.
        model = load_model(checkpoint_template.format(check_name),
                           custom_objects={'dice_coef_loss': dice_coef_loss,
                                           'dice_coef': dice_coef})
    else:
        model = unet_model()

    # NOTE(review): nb_epoch / samples_per_epoch / nb_val_samples are the
    # Keras 1.x argument names; the sample counts (551 / 50) are hard-coded
    # and presumably match the dataset used with the generators -- confirm
    # before reusing with other data.
    model.fit_generator(generate_train(start_t, end_t), nb_epoch=150, verbose=1,
                        validation_data=generate_val(start_v, end_v),
                        callbacks=callbacks,
                        samples_per_epoch=551, nb_val_samples=50)


# In[5]:
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号