def train(self):
    """Run the training loop for ``self.flags['num_epochs']`` epochs.

    Each epoch visits every entry of ``self.names['TRAIN']`` in the order
    produced by ``randomize_training_order`` and runs one training iteration
    per entry. When ``self.step`` is a multiple of
    ``self.flags['display_step']`` the iteration goes through
    ``_record_train_metrics``; otherwise it goes through ``_run_train_iter``.
    Either way the returned summary is handed to ``_record_training_step``.

    After each epoch (epochs are numbered from 1):
      * the model is saved via ``_save_model(section=epoch)`` when the epoch
        is a multiple of ``cfg.CHECKPOINT_RATE``;
      * validation runs via ``evaluate(test=False)`` when the epoch is a
        multiple of ``cfg.VALID_RATE``.

    NOTE(review): saving happens only on checkpoint epochs. If ``num_epochs``
    is not a multiple of ``cfg.CHECKPOINT_RATE``, this method does not save
    the weights from the final epochs.
    """
    num_epochs = self.flags['num_epochs']
    self.print_log('Training for %d epochs' % num_epochs)

    # Training-split inputs (x, im_dims, gt_boxes) handed to create_feed_dict.
    tf_inputs = (self.x['TRAIN'], self.im_dims['TRAIN'], self.gt_boxes['TRAIN'])

    # Loop-invariant configuration, read once. Use self.flags consistently:
    # the original read a bare global `flags` here, which fails with
    # NameError when no such global exists.
    data_directory = self.flags['data_directory']
    display_step = self.flags['display_step']
    train_names = self.names['TRAIN']

    # NOTE(review): self.step is advanced only once here, before training.
    # Nothing in this method increments it per iteration; presumably
    # _record_training_step does. Confirm, because otherwise the
    # display_step check below yields the same result on every iteration.
    self.step += 1

    # self.epoch is assigned as an instance attribute, presumably so other
    # methods can see the current epoch.
    for self.epoch in trange(1, num_epochs + 1, desc='epochs'):
        # Presumably a shuffled ordering of indices into train_names.
        train_order = randomize_training_order(len(train_names))

        for i in tqdm(train_order):
            feed_dict = create_feed_dict(data_directory, train_names, tf_inputs, i)

            # Run a training iteration. Every display_step steps, also
            # record training metrics.
            if self.step % display_step == 0:
                summary = self._record_train_metrics(feed_dict)
            else:
                summary = self._run_train_iter(feed_dict)
            self._record_training_step(summary)

        # Epoch finished: periodic checkpoint, then periodic validation.
        if self.epoch % cfg.CHECKPOINT_RATE == 0:
            self._save_model(section=self.epoch)
        if self.epoch % cfg.VALID_RATE == 0:
            self.evaluate(test=False)

        # NOTE: learning-rate decay is currently disabled. To re-enable:
        # if self.epoch % cfg.TRAIN.LEARNING_RATE_DECAY_RATE == 0:
        #     self.lr = self.lr * cfg.TRAIN.LEARNING_RATE_DECAY
        #     self.print_log("Learning Rate: %f" % self.lr)
faster_rcnn_resnet50ish.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录