def train_unet(self):
    """Train the U-Net on the image/mask pairs below ``self.flag.data_path``.

    Steps, in order:

    1. Build two identically configured ``ImageDataGenerator`` streams over
       ``train/IMAGE`` and ``train/GT``, both driven by the same random seed.
    2. Configure the TensorFlow session so GPU memory is allocated on demand.
    3. Build the model with ``get_unet`` and, when
       ``self.flag.pretrained_weight_path`` is set, load those weights.
    4. Write the model architecture to ``model.json`` inside
       ``<ckpt_dir>/<ckpt_name>`` (the directory is created if missing).
    5. Run ``fit_generator`` with checkpointing, the ``lr_step_decay``
       learning-rate schedule and the ``trainCheck`` callback.

    Settings read from ``self.flag``: ``batch_size``, ``total_epoch``,
    ``data_path``, ``pretrained_weight_path``, ``ckpt_dir`` and ``ckpt_name``
    (the whole flag object is also handed to ``get_unet`` and ``trainCheck``).
    """
    batch_size = self.flag.batch_size
    epochs = self.flag.total_epoch

    # Light geometric augmentation only: every normalisation option and both
    # flips are explicitly switched off.
    datagen_args = dict(
        featurewise_center=False,  # do not zero the dataset mean
        samplewise_center=False,  # do not zero each sample's mean
        featurewise_std_normalization=False,  # do not divide by dataset std
        samplewise_std_normalization=False,  # do not divide by sample std
        zca_whitening=False,  # no ZCA whitening
        rotation_range=5,  # random rotation range, in degrees
        width_shift_range=0.05,  # horizontal shift, fraction of total width
        height_shift_range=0.05,  # vertical shift, fraction of total height
        # fill_mode='constant',
        # cval=0.,
        horizontal_flip=False,  # no random horizontal flips
        vertical_flip=False,  # no random vertical flips
    )
    image_datagen = ImageDataGenerator(**datagen_args)
    mask_datagen = ImageDataGenerator(**datagen_args)

    # The same seed goes to both streams so that images and ground-truth
    # masks are expected to receive matching shuffling and transforms.
    seed = random.randrange(1, 1000)
    # NOTE(review): self.flag.image_size is not passed as target_size, so both
    # generators use Keras' default target size -- confirm that this matches
    # the input size of the model built by get_unet.
    image_generator = image_datagen.flow_from_directory(
        os.path.join(self.flag.data_path, 'train/IMAGE'),
        class_mode=None, seed=seed, batch_size=batch_size,
        color_mode='grayscale')
    mask_generator = mask_datagen.flow_from_directory(
        os.path.join(self.flag.data_path, 'train/GT'),
        class_mode=None, seed=seed, batch_size=batch_size,
        color_mode='grayscale')

    # Let TensorFlow grow its GPU allocation on demand instead of reserving
    # memory up front.
    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.9
    config.gpu_options.allow_growth = True
    set_session(tf.Session(config=config))

    model = get_unet(self.flag)
    # Identity test: None is a singleton, and `!= None` can be hijacked by a
    # custom __ne__ on the compared object.
    if self.flag.pretrained_weight_path is not None:
        model.load_weights(self.flag.pretrained_weight_path)

    # Everything produced by this run lives under <ckpt_dir>/<ckpt_name>.
    ckpt_path = os.path.join(self.flag.ckpt_dir, self.flag.ckpt_name)
    if not os.path.exists(ckpt_path):
        mkdir_p(ckpt_path)

    # Save the architecture next to the weight files.
    model_json = model.to_json()
    with open(os.path.join(ckpt_path, 'model.json'), 'w') as json_file:
        json_file.write(model_json)

    vis = callbacks.trainCheck(self.flag)
    # Checkpoint interval is epochs // 10 + 1; the epoch number is
    # substituted into the file name.
    model_checkpoint = ModelCheckpoint(
        os.path.join(ckpt_path, 'weights.{epoch:03d}.h5'),
        period=epochs // 10 + 1)
    learning_rate = LearningRateScheduler(self.lr_step_decay)

    # NOTE(review): floor division drops a trailing partial batch and yields
    # 0 steps when there are fewer images than batch_size -- confirm that the
    # training set is always at least one batch large.
    model.fit_generator(
        self.train_generator(image_generator, mask_generator),
        steps_per_epoch=image_generator.n // batch_size,
        epochs=epochs,
        callbacks=[model_checkpoint, learning_rate, vis]
    )
Unet_train.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录