def __init__(self, model_dir=None):
    """Build the detection graph, its loss, and an Adam optimizer, then
    restore the most recent checkpoint from ``model_dir`` if one exists.

    Relies on module-level names not visible in this chunk: ``tf``
    (TensorFlow 1.x), ``model`` (graph/loss/default-box builders), ``c``
    (shared config object), and ``FLAGS``.

    Args:
        model_dir: Directory to search for a checkpoint; defaults to
            ``FLAGS.model_dir`` when None.
    """
    self.sess = tf.Session()
    # model.model() builds the network: input-image placeholder, a `bn`
    # tensor (presumably a batch-norm/training flag — confirm in model.py),
    # the raw feature-map output tensors, and per-default-box class
    # predictions and location predictions.
    self.imgs_ph, self.bn, self.output_tensors, self.pred_labels, self.pred_locs = model.model(self.sess)
    # Total number of default boxes = second dim of the label predictions.
    total_boxes = self.pred_labels.get_shape().as_list()[1]
    self.positives_ph, self.negatives_ph, self.true_labels_ph, self.true_locs_ph, self.total_loss, self.class_loss, self.loc_loss = \
        model.loss(self.pred_labels, self.pred_locs, total_boxes)
    # Publish the output feature-map shapes and the matching default boxes
    # on the shared config object `c` for use elsewhere in the project.
    out_shapes = [out.get_shape().as_list() for out in self.output_tensors]
    c.out_shapes = out_shapes
    c.defaults = model.default_boxes(out_shapes)
    # variables in model are already initialized, so only initialize those declared after
    with tf.variable_scope("optimizer"):
        self.global_step = tf.Variable(0)
        self.lr_ph = tf.placeholder(tf.float32)
        # NOTE(review): lr_ph is declared but NOT passed to AdamOptimizer —
        # the learning rate is hard-coded to 1e-3, so feeding lr_ph has no
        # effect. Likely intended: tf.train.AdamOptimizer(self.lr_ph).
        # Not changed here because callers' feed_dicts would have to change.
        self.optimizer = tf.train.AdamOptimizer(1e-3).minimize(self.total_loss, global_step=self.global_step)
    # Initialize only the variables created in the "optimizer" scope (Adam
    # slot variables, global_step), leaving model weights untouched.
    new_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="optimizer")
    init = tf.variables_initializer(new_vars)
    self.sess.run(init)
    if model_dir is None:
        model_dir = FLAGS.model_dir
    # Restore the latest checkpoint, if any, overwriting the freshly
    # initialized values above.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    self.saver = tf.train.Saver()
    if ckpt and ckpt.model_checkpoint_path:
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        print("restored %s" % ckpt.model_checkpoint_path)