def train(self):
    """Train the network, periodically logging summaries/losses and checkpointing.

    Everything used here is read from ``self`` or module scope:

    * Input: shuffled batches of ``self.train_batch_size`` examples from
      ``self.train_file`` and ``self.test_batch_size`` examples from
      ``self.test_file``, produced by ``read_tfrecord`` with image size
      ``[self.image_height, self.image_width, self.image_depth]``.
      ``build_img_pair`` turns each fetched batch into the
      ``(input, ground truth)`` pair fed to ``self.input_data`` / ``self.gt``.
    * Initialization: all global variables are initialized, then the
      ``vgg_16`` model variables are restored from the relative path
      ``vgg_16.ckpt``.
    * Training: ``self.max_iteration`` runs of ``self.train_step`` with a
      stepped learning rate: 1e-4 for steps 0..500, 0.5e-4 up to step
      ``self.max_iteration - 1000``, then 1e-5.
    * Logging: every 10 steps, ``self.summary`` is written under
      ``self.log_dir + "/train"`` and ``self.log_dir + "/test"``, and
      ``self.cross_entropy`` is printed for both sets. The test metrics always
      use ONE fixed test batch drawn once before training starts.
    * Checkpoints: every 100 steps and after the last step, saved via
      ``tf.train.Saver`` to ``self.log_dir + "/model.ckpt"`` with the step as
      ``global_step``.

    The summary writers are closed and the queue-runner threads are stopped
    and joined even if training raises; the exception then propagates.
    """
    img_size = [self.image_height, self.image_width, self.image_depth]
    train_batch = tf.train.shuffle_batch(
        [read_tfrecord(self.train_file, img_size)],
        batch_size=self.train_batch_size,
        capacity=2000,
        num_threads=2,
        min_after_dequeue=1000)
    test_batch = tf.train.shuffle_batch(
        [read_tfrecord(self.test_file, img_size)],
        batch_size=self.test_batch_size,
        capacity=500,
        num_threads=2,
        min_after_dequeue=300)

    log_every = 10    # steps between summary writes / loss prints
    save_every = 100  # steps between checkpoints
    last_step = self.max_iteration - 1

    init = tf.global_variables_initializer()
    # Overwrite the freshly initialized vgg_16 variables with pretrained values.
    init_fn = slim.assign_from_checkpoint_fn(
        "vgg_16.ckpt", slim.get_model_variables('vgg_16'))
    saver = tf.train.Saver()
    checkpoint_path = self.log_dir + "/model.ckpt"

    with tf.Session() as sess:
        sess.run(init)
        init_fn(sess)
        train_writer = tf.summary.FileWriter(self.log_dir + "/train", sess.graph)
        test_writer = tf.summary.FileWriter(self.log_dir + "/test", sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # A single test batch is drawn once and reused for every evaluation.
            inputs_test, outputs_gt_test = build_img_pair(sess.run(test_batch))
            test_feed = {self.input_data: inputs_test,
                         self.gt: outputs_gt_test,
                         self.is_training: False}

            for step in range(self.max_iteration):
                inputs_train, outputs_gt_train = build_img_pair(sess.run(train_batch))

                # Stepped learning-rate schedule.
                if step <= 500:
                    learning_rate = 1e-4
                elif step <= self.max_iteration - 1000:
                    learning_rate = 0.5e-4
                else:
                    learning_rate = 1e-5
                self.train_step.run({self.input_data: inputs_train,
                                     self.gt: outputs_gt_train,
                                     self.learning_rate: learning_rate,
                                     self.is_training: True})

                if step % log_every == 0:
                    train_feed = {self.input_data: inputs_train,
                                  self.gt: outputs_gt_train,
                                  self.is_training: False}
                    # Fetch summary and loss together: one forward pass per set.
                    summary_train, train_loss = sess.run(
                        [self.summary, self.cross_entropy], train_feed)
                    summary_test, test_loss = sess.run(
                        [self.summary, self.cross_entropy], test_feed)
                    train_writer.add_summary(summary_train, step)
                    train_writer.flush()
                    test_writer.add_summary(summary_test, step)
                    test_writer.flush()
                    print("iter step %d trainning batch loss %f"%(step, train_loss))
                    print("iter step %d test loss %f\n"%(step, test_loss))

                if step % save_every == 0:
                    saver.save(sess, checkpoint_path, global_step=step)

            # Persist the final weights unless the last step was already saved above.
            if last_step >= 0 and last_step % save_every != 0:
                saver.save(sess, checkpoint_path, global_step=last_step)
        finally:
            # Always release writers and queue-runner threads, even on error.
            train_writer.close()
            test_writer.close()
            coord.request_stop()
            coord.join(threads)
评论列表
文章目录