def train_segment(train_imgs, train_masks, train_index, train_i, val_i, factor, factor_val):
    """Build, compile and train the segmentation model.

    The model is created by ``models.segment()``, compiled with Adam and a
    weighted soft-Dice loss, and trained with ``model.fit_generator`` on
    batches from ``dm.data_generator_segment``. A learning-rate step
    schedule is applied, and the weights are checkpointed after every epoch.

    Args:
        train_imgs: Image array. Rows are selected with ``train_i`` for
            training and with ``val_i`` for validation.
        train_masks: Ground-truth mask array, indexed the same way as
            ``train_imgs``.
        train_index: Per-sample index array (at least 2-D; it is indexed as
            ``train_index[train_i, 0]``). Only the rows in ``train_i`` are
            passed to the generator.
        train_i: Indices of the training samples.
        val_i: Indices of the validation samples.
        factor: Extra weight in the training loss for samples whose
            ground-truth mask sums to >= 1. Those samples get weight
            ``factor + 1``; all others get weight 1.
        factor_val: The same extra weight, used by the ``dice_coef_wval``
            metric.

    Returns:
        Tuple ``(model, epoch_history)``, where ``epoch_history`` is the
        value returned by ``model.fit_generator``.

    Side effects:
        Prints ``'training starts...'``. Writes
        ``Saved/model2_weights_epoch_{epoch:02d}.hdf5`` after every epoch
        (presumably the ``Saved`` directory must already exist - confirm).
    """

    def _weighted_dice_loss(y_true, y_pred, pos_factor):
        # Compute the soft-Dice overlap, reducing over the last two axes
        # (presumably the rows/cols of each mask - confirm).
        intersection = K.sum(K.sum(y_true * y_pred, axis=-1), axis=-1)
        sum_pred = K.sum(K.sum(y_pred, axis=-1), axis=-1)
        sum_true = K.sum(K.sum(y_true, axis=-1), axis=-1)
        # Non-empty ground truth gets weight pos_factor + 1; empty gets 1.
        # The boolean comparison is cast explicitly, because multiplying a
        # bool tensor by a number fails on the TensorFlow backend.
        has_target = K.cast(K.greater_equal(sum_true, 1), K.floatx())
        weighting = has_target * pos_factor + 1
        dice = (2. * intersection + smooth) / (sum_true + sum_pred + smooth)
        # Negate so that minimising the loss maximises the weighted Dice.
        return -K.mean(weighting * dice)

    # These stay as two named functions (not partials/lambdas) because
    # Keras labels losses/metrics in its logs by the function's __name__.
    def dice_coef(y_true, y_pred):
        """Training loss: negative weighted soft Dice using ``factor``."""
        return _weighted_dice_loss(y_true, y_pred, factor)

    def dice_coef_wval(y_true, y_pred):
        """Metric: negative weighted soft Dice using ``factor_val``."""
        return _weighted_dice_loss(y_true, y_pred, factor_val)

    model = models.segment()
    model.compile(optimizer=Adam(lr=1e-2), loss=dice_coef,
                  metrics=[dice_coef_wval, dice_tresh, pres_acc])

    augmentation_ratio, data_generator = dm.data_generator_segment(
        nb_rows_small, nb_cols_small, nb_rows_mask_small, nb_cols_mask_small)

    # Piecewise-constant learning rate, as (last epoch of step, rate) pairs:
    # epoch <= 5: 1e-2, <= 10: 5e-3, <= 25: 2e-3, <= 40: 1e-3, else 5e-4.
    lr_steps = ((5, 1e-2), (10, 5e-3), (25, 2e-3), (40, 1e-3))
    final_lr = 5e-4

    def schedule(epoch):
        """Return the learning rate for ``epoch`` from ``lr_steps``."""
        for last_epoch, rate in lr_steps:
            if epoch <= last_epoch:
                return rate
        return final_lr

    lr_schedule = LearningRateScheduler(schedule)
    checkpoint = ModelCheckpoint('Saved/model2_weights_epoch_{epoch:02d}.hdf5',
                                 verbose=0, save_best_only=False)

    # The generator batch size is the number of distinct values in column 0
    # of train_index over the training rows.
    batch_size = len(np.unique(train_index[train_i, 0]))

    print('training starts...')
    epoch_history = model.fit_generator(
        data_generator(train_imgs[train_i], train_masks[train_i],
                       train_index[train_i], batch_size=batch_size),
        samples_per_epoch=augmentation_ratio * len(train_i),
        nb_epoch=50,
        callbacks=[lr_schedule, checkpoint],
        validation_data=(train_imgs[val_i], train_masks[val_i]),
        max_q_size=10)
    return model, epoch_history
#==============================================================================
# Data importation and processing
#==============================================================================
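# --- Scraped web-page residue (not part of the program) ---
# train.py 文件源码 (train.py source file)
# python
# 阅读 25 (reads: 25)
# 收藏 0 (favorites: 0)
# 点赞 0 (likes: 0)
# 评论 0 (comments: 0)
# 评论列表 (comment list)
# 文章目录 (table of contents)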