def create_training_batch(serialized_example, cfg, add_summaries):
    """Build the named input/label tensors for one serialized training example.

    Args:
        serialized_example: Serialized record forwarded to `get_region_data`.
        cfg: Configuration object passed through to `get_region_data` and
            `get_distorted_inputs`.
        add_summaries: Forwarded unchanged to `get_distorted_inputs`.

    Returns:
        A two-element list `[names, tensors]`, where `names` is the tuple
        `('inputs', 'labels')` and `tensors` is the list
        `[rescaled_inputs, labels]` in matching order.
    """
    region = get_region_data(
        serialized_example,
        cfg,
        fetch_ids=False,
        fetch_labels=True,
        fetch_text_labels=False,
    )
    image, boxes, targets = region['image'], region['bboxes'], region['labels']

    # Apply (x - 0.5) * 2 to the distorted inputs.
    # NOTE(review): presumably maps a [0, 1] range onto [-1, 1] -- confirm the
    # value range produced by get_distorted_inputs.
    centered = tf.subtract(
        get_distorted_inputs(image, boxes, cfg, add_summaries), 0.5)
    rescaled = tf.multiply(centered, 2.0)

    return [('inputs', 'labels'), [rescaled, targets]]
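# NOTE: the source page ended here with its "comment list" and
# "table of contents" navigation labels; they are not part of the code.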