faster_rcnn_resnet50ish.py source code


Project: tf-Faster-RCNN    Author: kevinjliang
def train(self):
        """ Run training function. Save model upon completion """
        self.print_log('Training for %d epochs' % self.flags['num_epochs'])

        tf_inputs = (self.x['TRAIN'], self.im_dims['TRAIN'], self.gt_boxes['TRAIN'])

        self.step += 1
        for self.epoch in trange(1, self.flags['num_epochs']+1, desc='epochs'):
            train_order = randomize_training_order(len(self.names['TRAIN']))

            for i in tqdm(train_order):
                # Note: this uses the module-level 'flags' dict here, not self.flags
                feed_dict = create_feed_dict(flags['data_directory'], self.names['TRAIN'], tf_inputs, i)

                # Run a training iteration
                if self.step % (self.flags['display_step']) == 0:
                    # Record training metrics every display_step interval
                    summary = self._record_train_metrics(feed_dict)
                    self._record_training_step(summary)
                else: 
                    summary = self._run_train_iter(feed_dict)
                    self._record_training_step(summary)             

            ## Epoch finished
            # Save model 
            if self.epoch % cfg.CHECKPOINT_RATE == 0: 
                self._save_model(section=self.epoch)
            # Perform validation
            if self.epoch % cfg.VALID_RATE == 0: 
                self.evaluate(test=False)
#            # Adjust learning rate
#            if self.epoch % cfg.TRAIN.LEARNING_RATE_DECAY_RATE == 0:
#                self.lr = self.lr * cfg.TRAIN.LEARNING_RATE_DECAY
#                self.print_log("Learning Rate: %f" % self.lr)
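
The training loop above relies on two helpers imported from elsewhere in the repository, randomize_training_order and create_feed_dict. The sketch below shows one plausible implementation of each, purely for illustration: the single-channel image assumption, the file layout under data_directory, and the annotation format are guesses, not the project's actual data pipeline.

import numpy as np
from PIL import Image

def randomize_training_order(num_examples):
    """Return a shuffled list of example indices for one epoch (assumed behavior)."""
    order = np.arange(num_examples)
    np.random.shuffle(order)
    return list(order)

def create_feed_dict(data_directory, names, tf_inputs, idx):
    """Load one example and map it onto the (image, im_dims, gt_boxes) placeholders
    that train() unpacked into tf_inputs. Paths and formats here are hypothetical."""
    x, im_dims, gt_boxes = tf_inputs
    name = names[idx]
    # Hypothetical file layout: one image and one annotation file per example name
    image = np.array(Image.open(data_directory + 'Images/' + name + '.png'), dtype=np.float32)
    boxes = np.loadtxt(data_directory + 'Annotations/' + name + '.txt', ndmin=2)  # rows: [x1, y1, x2, y2, class]
    return {
        x: image[np.newaxis, :, :, np.newaxis],   # batch of one single-channel image
        im_dims: np.array([image.shape[:2]]),     # [[height, width]]
        gt_boxes: boxes,
    }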