inputs.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:tf_classification 作者: visipedia 项目源码 文件源码
def create_training_batch(serialized_example, cfg, add_summaries):

    features = get_region_data(serialized_example, cfg, fetch_ids=False,
                               fetch_labels=True, fetch_text_labels=False)

    original_image = features['image']
    bboxes = features['bboxes']
    labels = features['labels']

    distorted_inputs = get_distorted_inputs(original_image, bboxes, cfg, add_summaries)

    distorted_inputs = tf.subtract(distorted_inputs, 0.5)
    distorted_inputs = tf.multiply(distorted_inputs, 2.0)

    names = ('inputs', 'labels')
    tensors = [distorted_inputs, labels]
    return [names, tensors]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号