def create_classification_batch(serialized_example, cfg, add_summaries):
    """Build the named input tensors for one classification example.

    The serialized record is decoded with ``get_region_data``, requesting ids
    but neither numeric nor text labels. The decoded image and its boxes are
    passed through ``get_distorted_inputs``. The result is then rescaled as
    ``(x - 0.5) * 2.0``.

    Args:
        serialized_example: Serialized record, forwarded unchanged to
            ``get_region_data``.
        cfg: Configuration object, forwarded to both ``get_region_data`` and
            ``get_distorted_inputs``.
        add_summaries: Forwarded unchanged to ``get_distorted_inputs``.

    Returns:
        A two-element list ``[names, tensors]``:
        - ``names`` is the tuple ``('inputs', 'ids')``.
        - ``tensors`` is a list holding the rescaled distorted image and the
          ids, in that same order.
    """
    region = get_region_data(
        serialized_example,
        cfg,
        fetch_ids=True,
        fetch_labels=False,
        fetch_text_labels=False,
    )

    # Pull every field up front so a missing key fails before any distortion work.
    image = region['image']
    boxes = region['bboxes']
    example_ids = region['ids']

    distorted = get_distorted_inputs(image, boxes, cfg, add_summaries)

    # Shift by 0.5, then scale by 2.0. This presumably maps a [0, 1] image
    # into [-1, 1]. NOTE(review): confirm the output range of
    # get_distorted_inputs.
    centered = tf.subtract(distorted, 0.5)
    scaled = tf.multiply(centered, 2.0)

    return [('inputs', 'ids'), [scaled, example_ids]]
评论列表
文章目录