Unet_train.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:segmentation-visualization-training 作者: tkwoo 项目源码 文件源码
def train_unet(self):
        """Train the U-Net on the image/mask folders below ``flag.data_path``.

        Steps performed, in order:

        1. Build two ``ImageDataGenerator`` objects with identical augmentation
           settings and stream ``train/IMAGE`` and ``train/GT`` from
           ``self.flag.data_path`` as grayscale batches.
        2. Install a TensorFlow session with GPU memory growth enabled.
        3. Build the model with ``get_unet`` and, when
           ``self.flag.pretrained_weight_path`` is set, load those weights.
        4. Create the checkpoint directory
           ``<flag.ckpt_dir>/<flag.ckpt_name>`` if missing and write the model
           architecture there as ``model.json``.
        5. Run ``fit_generator`` for ``self.flag.total_epoch`` epochs with
           checkpointing, learning-rate scheduling (``self.lr_step_decay``) and
           the project's ``trainCheck`` callback.

        Returns:
            None. The results are the files written to the checkpoint
            directory and the trained in-memory model, which is not returned.
        """
        batch_size = self.flag.batch_size
        epochs = self.flag.total_epoch

        # NOTE(review): flag.image_size is not consulted here. Both
        # flow_from_directory calls below omit target_size, so Keras applies
        # its documented default resize (256x256) -- confirm that this matches
        # the input size get_unet builds the model for.
        datagen_args = dict(featurewise_center=False,  # set input mean to 0 over the dataset
                samplewise_center=False,  # set each sample mean to 0
                featurewise_std_normalization=False,  # divide inputs by std of the dataset
                samplewise_std_normalization=False,  # divide each input by its std
                zca_whitening=False,  # apply ZCA whitening
                rotation_range=5,  # randomly rotate images in the range (degrees, 0 to 180)
                width_shift_range=0.05,  # randomly shift images horizontally (fraction of total width)
                height_shift_range=0.05,  # randomly shift images vertically (fraction of total height)
                # fill_mode='constant',
                # cval=0.,
                horizontal_flip=False,  # randomly flip images
                vertical_flip=False)  # randomly flip images

        image_datagen = ImageDataGenerator(**datagen_args)
        mask_datagen = ImageDataGenerator(**datagen_args)

        # The same seed is handed to both flows, presumably so that every image
        # and its ground-truth mask get the same ordering and random transform.
        seed = random.randrange(1, 1000)
        image_generator = image_datagen.flow_from_directory(
                    os.path.join(self.flag.data_path, 'train/IMAGE'),
                    class_mode=None, seed=seed, batch_size=batch_size, color_mode='grayscale')
        mask_generator = mask_datagen.flow_from_directory(
                    os.path.join(self.flag.data_path, 'train/GT'),
                    class_mode=None, seed=seed, batch_size=batch_size, color_mode='grayscale')

        # Let TensorFlow grow GPU memory on demand instead of reserving it all.
        config = tf.ConfigProto()
        # config.gpu_options.per_process_gpu_memory_fraction = 0.9
        config.gpu_options.allow_growth = True
        set_session(tf.Session(config=config))

        model = get_unet(self.flag)
        if self.flag.pretrained_weight_path is not None:
            model.load_weights(self.flag.pretrained_weight_path)

        # Every artefact of this run (architecture JSON, weight snapshots)
        # lives in one directory, so resolve its path a single time.
        ckpt_path = os.path.join(self.flag.ckpt_dir, self.flag.ckpt_name)
        if not os.path.exists(ckpt_path):
            mkdir_p(ckpt_path)
        with open(os.path.join(ckpt_path, 'model.json'), 'w') as json_file:
            json_file.write(model.to_json())

        vis = callbacks.trainCheck(self.flag)
        # Snapshot interval is total_epoch // 10 + 1 epochs.
        model_checkpoint = ModelCheckpoint(
                    os.path.join(ckpt_path, 'weights.{epoch:03d}.h5'),
                    period=self.flag.total_epoch // 10 + 1)
        learning_rate = LearningRateScheduler(self.lr_step_decay)

        # NOTE(review): floor division drops a trailing partial batch and
        # evaluates to 0 when image_generator.n < batch_size -- confirm that
        # the training set is always at least one batch large.
        model.fit_generator(
            self.train_generator(image_generator, mask_generator),
            steps_per_epoch=image_generator.n // batch_size,
            epochs=epochs,
            callbacks=[model_checkpoint, learning_rate, vis]
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号