train.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Ultras-Sound-Nerve-Segmentation---Kaggle 作者: Simoncarbo 项目源码 文件源码
def train_segment(train_imgs, train_masks, train_index, train_i, val_i, factor, factor_val):
    """Build, compile and train the segmentation model.

    Args:
        train_imgs: Image array; rows are selected with ``train_i`` / ``val_i``.
        train_masks: Mask array aligned row-for-row with ``train_imgs``.
        train_index: Per-sample index array. The number of unique values in
            column 0 of the training rows is used as the generator batch size.
        train_i: Indices of the training samples.
        val_i: Indices of the validation samples.
        factor: Extra weight added to the Dice term of samples whose
            ground-truth mask sums to at least 1 (training loss).
        factor_val: Same weighting, used by the ``dice_coef_wval`` metric.

    Returns:
        Tuple ``(model, epoch_history)``: the trained model and the history
        object returned by ``model.fit_generator``.
    """

    def weighted_dice_loss(y_true, y_pred, weight_factor):
        # Reduce over the last two axes (presumably the mask rows/cols -- confirm).
        intersection = K.sum(K.sum(y_true * y_pred, axis=-1), axis=-1)
        sum_pred = K.sum(K.sum(y_pred, axis=-1), axis=-1)
        sum_true = K.sum(K.sum(y_true, axis=-1), axis=-1)

        # 1 for samples whose ground-truth mask sums to >= 1, else 0.
        # Cast explicitly: greater_equal yields a boolean tensor on the
        # TensorFlow backend, which cannot be multiplied by a float directly.
        has_target = K.cast(K.greater_equal(sum_true, 1), K.floatx())
        weighting = has_target * weight_factor + 1

        # `smooth` is a module-level constant defined elsewhere in this file
        # (presumably > 0 so that empty masks do not give 0/0 -- confirm).
        dice = (2. * intersection + smooth) / (sum_true + sum_pred + smooth)
        # Negated so that minimizing the loss maximizes the weighted Dice.
        return -K.mean(weighting * dice)

    # Separate named functions (rather than one parameterized closure) keep the
    # loss/metric names that Keras reports in the training history.
    def dice_coef(y_true, y_pred):
        """Training loss: negative weighted Dice using ``factor``."""
        return weighted_dice_loss(y_true, y_pred, factor)

    def dice_coef_wval(y_true, y_pred):
        """Metric: negative weighted Dice using ``factor_val``."""
        return weighted_dice_loss(y_true, y_pred, factor_val)

    model = models.segment()

    # dice_tresh and pres_acc are metric functions defined elsewhere in the file.
    model.compile(optimizer=Adam(lr=1e-2), loss=dice_coef,
                  metrics=[dice_coef_wval, dice_tresh, pres_acc])

    augmentation_ratio, data_generator = dm.data_generator_segment(
        nb_rows_small, nb_cols_small, nb_rows_mask_small, nb_cols_mask_small)

    def schedule(epoch):
        # Step-wise learning-rate decay keyed on the epoch index; passed to
        # LearningRateScheduler below.
        if epoch <= 5:
            return 1e-2
        elif epoch <= 10:
            return 5e-3
        elif epoch <= 25:
            return 2e-3
        elif epoch <= 40:
            return 1e-3
        else:
            return 5e-4

    lr_schedule = LearningRateScheduler(schedule)
    # save_best_only=False: a checkpoint file is written for every epoch.
    modelCheck = ModelCheckpoint('Saved/model2_weights_epoch_{epoch:02d}.hdf5',
                                 verbose=0, save_best_only=False)

    print('training starts...')
    train_generator = data_generator(
        train_imgs[train_i], train_masks[train_i], train_index[train_i],
        batch_size=len(np.unique(train_index[train_i, 0])))
    # samples_per_epoch / nb_epoch / max_q_size are Keras 1.x argument names.
    epoch_history = model.fit_generator(
        train_generator,
        samples_per_epoch=augmentation_ratio * len(train_i),
        nb_epoch=50,
        callbacks=[lr_schedule, modelCheck],
        validation_data=(train_imgs[val_i], train_masks[val_i]),
        max_q_size=10)

    return model, epoch_history

#==============================================================================
# Data importation and processing
#==============================================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号