demo_detect.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目: yolo-tf 作者: ruiminshen 项目源码 文件源码
def main():
    """Run a single-example YOLO demo: restore a checkpoint and visualize it.

    Steps:
    1. Read the model name from the module-level ``config`` and load the
       class names from ``<cachedir>/names``.
    2. Build the ``model.<name>`` network and its objectives on placeholders
       with a leading batch dimension of 1.
    3. Pull one (image, labels) example from the TFRecord files selected by
       ``args.profile``.
    4. Restore the newest checkpoint from the configured log directory.
    5. Log the global step and the total loss for that example.
    6. Open a ``Drawer`` visualization with ``plt.show()``.

    Raises:
        FileNotFoundError: if the log directory does not exist, or it
            contains no checkpoint to restore.
    """
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    yolo = importlib.import_module('model.' + model)
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    with tf.Session() as sess:
        paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
        # Count records only for logging; this iterates every file once.
        num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
        tf.logging.warn('num_examples=%d' % num_examples)
        image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
        image_std = tf.image.per_image_standardization(image_rgb)
        image_rgb = tf.cast(image_rgb, tf.uint8)
        # The network is built on a placeholder with a leading dimension of 1,
        # so the single pre-fetched example can be fed back in via feed_dict.
        ph_image = tf.placeholder(image_std.dtype, [1] + image_std.get_shape().as_list(), name='ph_image')
        global_step = tf.contrib.framework.get_or_create_global_step()
        builder = yolo.Builder(args, config)
        builder(ph_image)
        # Snapshot the restorable variables right after building the network,
        # before the objectives below are created.
        variables_to_restore = slim.get_variables_to_restore()
        ph_labels = [tf.placeholder(l.dtype, [1] + l.get_shape().as_list(), name='ph_' + l.op.name) for l in labels]
        with tf.name_scope('total_loss') as name:
            builder.create_objectives(ph_labels)
            total_loss = tf.losses.get_total_loss(name=name)
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            # Fetch exactly one example from the input pipeline.
            _image_rgb, _image_std, _labels = sess.run([image_rgb, image_std, labels])
        finally:
            # Always stop and join the queue-runner threads, even if the fetch fails.
            coord.request_stop()
            coord.join(threads)
        feed_dict = {ph: np.expand_dims(d, 0) for ph, d in zip(ph_labels, _labels)}
        feed_dict[ph_image] = np.expand_dims(_image_std, 0)
        logdir = utils.get_logdir(config)
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(logdir):
            raise FileNotFoundError('log directory does not exist: ' + logdir)
        model_path = tf.train.latest_checkpoint(logdir)
        if model_path is None:
            # latest_checkpoint returns None when logdir holds no checkpoint.
            raise FileNotFoundError('no checkpoint found in ' + logdir)
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, variables_to_restore)(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tf.logging.info('total_loss=%f' % sess.run(total_loss, feed_dict))
        # NOTE(review): the reference is presumably kept so the Drawer stays
        # alive while plt.show() blocks; confirm against Drawer's implementation.
        _ = Drawer(sess, names, builder.model.cell_width, builder.model.cell_height, _image_rgb, _labels, builder.model, feed_dict)
        plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号