trainer.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:textboxes 作者: shinjayne 项目源码 文件源码
def __init__(self, model_dir=None):
        """Build the training graph and optionally restore a checkpoint.

        Creates a session, builds the network with ``model.model`` and the
        loss with ``model.loss``, and publishes the output-layer shapes and
        default boxes on the ``c`` module. It then adds an Adam train op
        under the ``optimizer`` variable scope and initializes only the
        variables in that scope. Finally it restores the latest checkpoint
        recorded in ``model_dir``, if there is one.

        Args:
            model_dir: Directory passed to ``tf.train.get_checkpoint_state``.
                When ``None``, ``FLAGS.model_dir`` is used.

        Attributes set (among others):
            lr_ph: Scalar float32 learning-rate input for Adam. It defaults
                to 1e-3 when it is not fed.
            global_step: Non-trainable step counter incremented by the
                train op.
            optimizer: The train op returned by ``minimize()`` (not an
                ``Optimizer`` instance).
        """
        self.sess = tf.Session()

        # model.model receives the session. Per the note further below, the
        # variables it creates are expected to be initialized already.
        (self.imgs_ph, self.bn, self.output_tensors,
         self.pred_labels, self.pred_locs) = model.model(self.sess)

        # Static size of axis 1 of the label predictions. This is presumably
        # the total number of default boxes (TODO confirm against model.py).
        total_boxes = self.pred_labels.get_shape().as_list()[1]
        (self.positives_ph, self.negatives_ph, self.true_labels_ph,
         self.true_locs_ph, self.total_loss, self.class_loss,
         self.loc_loss) = model.loss(self.pred_labels, self.pred_locs, total_boxes)

        # Publish the static output shapes, and the default boxes derived from
        # them, on the shared `c` module (presumably the project config).
        out_shapes = [out.get_shape().as_list() for out in self.output_tensors]
        c.out_shapes = out_shapes
        c.defaults = model.default_boxes(out_shapes)

        # Variables in model are already initialized, so only initialize those
        # declared inside this scope.
        with tf.variable_scope("optimizer"):
            # A step counter is not a model weight, so keep it out of
            # TRAINABLE_VARIABLES. The variable name is unchanged, so existing
            # checkpoints still restore.
            self.global_step = tf.Variable(0, trainable=False)
            # The learning-rate input must actually reach Adam. Otherwise a fed
            # value is silently ignored. When nothing is fed it falls back to
            # the original hard-coded rate of 1e-3.
            self.lr_ph = tf.placeholder_with_default(
                tf.constant(1e-3, dtype=tf.float32), shape=[])
            self.optimizer = tf.train.AdamOptimizer(self.lr_ph).minimize(
                self.total_loss, global_step=self.global_step)
        new_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="optimizer")
        self.sess.run(tf.variables_initializer(new_vars))

        if model_dir is None:
            model_dir = FLAGS.model_dir

        ckpt = tf.train.get_checkpoint_state(model_dir)
        self.saver = tf.train.Saver()

        # Restore only when model_dir records a checkpoint path. Otherwise keep
        # the variables as built and initialized above.
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            print("restored %s" % ckpt.model_checkpoint_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号