# Maps the class spellings found in the spreadsheet onto the labels used
# downstream. Any spelling not listed here becomes NaN (Series.map semantics).
_CLASS_LABELS = {'preplus': 'pre-plus', 'normal': 'normal', 'plus': 'plus'}

# Cross-validation split definitions, keyed by split name:
#   (train-index key in the .mat file, test-index key,
#    output sub-directory name, progress message).
# NOTE(review): the original inline comments described 'PreP' as "Combining
# No and Pre-Plus" and 'Plus' as "Combining Pre-Plus and Plus", which is the
# opposite of the progress messages below. The messages are kept verbatim;
# confirm which grouping each .mat key really encodes.
_CV_SPLITS = {
    'PreP': ('trainPrePIndex', 'testPrePIndex', 'trainTestPreP',
             "Generating splits for combined Pre-Plus and Plus"),
    'Plus': ('trainPlusIndex', 'testPlusIndex', 'trainTestPlus',
             "Generating splits for combined No and Pre-Plus"),
}


def _load_class_table(mapping):
    """Read ``Sheet1`` of the *mapping* workbook with normalised class labels.

    The sheet is indexed by its ``index`` column and its ``class`` column is
    rewritten through ``_CLASS_LABELS``. Labels with no entry in that dict
    become NaN; a warning listing them is printed instead of dropping them
    silently.
    """
    with pd.ExcelFile(mapping) as xls:
        df = pd.read_excel(xls, 'Sheet1').set_index('index')
    labels = df['class']
    normalised = labels.map(_CLASS_LABELS)
    # Only flag values that were present but unrecognised, not cells that
    # were already empty in the spreadsheet.
    unrecognised = labels[normalised.isnull() & labels.notnull()]
    if len(unrecognised):
        print("Warning: unrecognised class labels %s were set to NaN"
              % list(unrecognised.unique()))
    df['class'] = normalised
    return df


def unet_cross_val(data_dir, out_dir, mapping, splits, unet_conf,
                   split_names=('PreP',), processes=1):
    """Generate cross-validation splits and train/test U-Net models on them.

    Parameters
    ----------
    data_dir : str
        Directory containing the ``images``, ``manual_segmentations`` and
        ``masks`` sub-directories.
    out_dir : str
        Directory under which one sub-directory per split is created.
    mapping : str
        Path to an Excel workbook whose ``Sheet1`` has ``index`` and
        ``class`` columns.
    splits : str
        Path to a MATLAB ``.mat`` file holding the train/test index arrays
        named in ``_CV_SPLITS``.
    unet_conf
        Model configuration, passed through unchanged to ``train_and_test``.
    split_names : sequence of str, optional
        Which entries of ``_CV_SPLITS`` to run. The default ``('PreP',)``
        reproduces the original behaviour. A single name may be passed as a
        plain string.
    processes : int, optional
        Forwarded to ``train_and_test``. Default 1, as before.

    Raises
    ------
    ValueError
        If *split_names* contains a name that is not in ``_CV_SPLITS``.
        This is checked before any file is read.
    """
    if isinstance(split_names, str):
        split_names = (split_names,)
    bad_names = [name for name in split_names if name not in _CV_SPLITS]
    if bad_names:
        raise ValueError("Unknown split name(s) %s; expected any of %s"
                         % (bad_names, sorted(_CV_SPLITS)))

    df = _load_class_table(mapping)
    img_dir = join(data_dir, 'images')
    seg_dir = join(data_dir, 'manual_segmentations')
    mask_dir = join(data_dir, 'masks')
    # Check whether all images exist before doing any splitting or training.
    check_images_exist(df, img_dir, seg_dir, mask_dir)

    cv_file = sio.loadmat(splits)
    split_dirs = []
    for name in split_names:
        train_key, test_key, sub_dir, message = _CV_SPLITS[name]
        # loadmat returns MATLAB arrays as (at least) 2-D numpy arrays; [0]
        # takes the first row. This presumes 1xN row vectors -- TODO confirm.
        train_idx = cv_file[train_key][0]
        test_idx = cv_file[test_key][0]
        split_dir = make_sub_dir(out_dir, sub_dir)
        print(message)
        generate_splits(train_idx, test_idx, df, img_dir, mask_dir, seg_dir,
                        split_dir)
        split_dirs.append(split_dir)

    # Train only once every requested split has been written, so a failure
    # while generating splits surfaces before any long training run starts.
    for split_dir in split_dirs:
        train_and_test(split_dir, unet_conf, processes=processes)
评论列表
文章目录