def main():
    """Restore the latest checkpoint and visualise the model on cached data.

    Uses the ``config`` and ``args`` objects that are defined outside this
    function. The steps are:

    1. Read the class names from ``<cachedir>/names`` and the model input
       size from ``config``.
    2. Build the input pipeline over the ``<profile>.tfrecord`` files named
       by ``args.profile`` and evaluate it once.
    3. Build the model selected by ``config`` on a placeholder with a leading
       axis of size 1, and create its total loss.
    4. Restore the variables from the latest checkpoint in the log directory
       and log the global step and the total loss.
    5. Pass the results to ``Drawer`` and show the matplotlib figure.

    Raises:
        IOError: if the log directory does not exist or holds no checkpoint.
    """
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    yolo = importlib.import_module('model.' + model)
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    with tf.Session() as sess:
        paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
        # Count the records by iterating every file once.
        num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
        tf.logging.warn('num_examples=%d' % num_examples)
        image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
        # Standardise before the uint8 cast so the model input is computed
        # from the tensor as loaded; the uint8 copy is only passed to Drawer.
        image_std = tf.image.per_image_standardization(image_rgb)
        image_rgb = tf.cast(image_rgb, tf.uint8)
        ph_image = tf.placeholder(image_std.dtype, [1] + image_std.get_shape().as_list(), name='ph_image')
        global_step = tf.contrib.framework.get_or_create_global_step()
        builder = yolo.Builder(args, config)
        builder(ph_image)
        # Collected before the objectives are created, so anything added by
        # create_objectives() is not part of the restore list.
        variables_to_restore = slim.get_variables_to_restore()
        ph_labels = [
            tf.placeholder(label.dtype, [1] + label.get_shape().as_list(), name='ph_' + label.op.name)
            for label in labels
        ]
        with tf.name_scope('total_loss') as name:
            builder.create_objectives(ph_labels)
            total_loss = tf.losses.get_total_loss(name=name)
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            _image_rgb, _image_std, _labels = sess.run([image_rgb, image_std, labels])
        finally:
            # Always stop and join the queue-runner threads, even when the
            # fetch above fails.
            coord.request_stop()
            coord.join(threads)
        # Add the leading axis of size 1 that the placeholders expect.
        feed_dict = {ph: np.expand_dims(d, 0) for ph, d in zip(ph_labels, _labels)}
        feed_dict[ph_image] = np.expand_dims(_image_std, 0)
        logdir = utils.get_logdir(config)
        if not os.path.exists(logdir):
            raise IOError('log directory does not exist: %s' % logdir)
        model_path = tf.train.latest_checkpoint(logdir)
        if model_path is None:
            # latest_checkpoint() yields None when it finds no checkpoint;
            # fail here instead of on the string concatenation below.
            raise IOError('no checkpoint found in %s' % logdir)
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, variables_to_restore)(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tf.logging.info('total_loss=%f' % sess.run(total_loss, feed_dict))
        # NOTE(review): the Drawer instance is bound to a name, presumably so
        # it stays referenced while the figure is shown - confirm against Drawer.
        _ = Drawer(sess, names, builder.model.cell_width, builder.model.cell_height, _image_rgb, _labels, builder.model, feed_dict)
        plt.show()
评论列表
文章目录