retrain.py source code

python

Project: Embarrassingly-Parallel-Image-Classification  Author: Azure
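import os

import tensorflow as tf  # TensorFlow 1.x: tf.contrib (and slim with it) was removed in TF 2.x

# read_label_file() is not defined in this snippet. It is assumed to be the helper
# from TF-slim's datasets/dataset_utils.py (or an equivalent defined elsewhere in
# retrain.py) that returns a {label_id: class_name} dict.


def get_dataset(dataset_name, dataset_dir, image_count, class_count, split_name):
    """Return a TF-slim Dataset describing one split stored as TFRecord shards."""
    slim = tf.contrib.slim
    items_to_descriptions = {'image': 'A color image.',
                             'label': 'An integer in range(0, class_count)'}
    # Shards are expected to be named <dataset_name>_<split_name>_*.tfrecord.
    file_pattern = os.path.join(dataset_dir, '{}_{}_*.tfrecord'.format(dataset_name, split_name))
    reader = tf.TFRecordReader  # the reader class, not an instance
    # Features stored in each serialized tf.train.Example.
    keys_to_features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
                        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
                        'image/class/label': tf.FixedLenFeature([], tf.int64,
                                                                default_value=tf.zeros([], dtype=tf.int64))}
    # Map the raw features to the 'image' and 'label' items the Dataset exposes.
    items_to_handlers = {'image': slim.tfexample_decoder.Image(),
                         'label': slim.tfexample_decoder.Tensor('image/class/label')}
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    labels_to_names = read_label_file(dataset_dir)
    # Extra keyword arguments (num_classes, labels_to_names, shuffle) are stored
    # as attributes on the Dataset object for downstream code to read.
    return slim.dataset.Dataset(data_sources=file_pattern,
                                reader=reader,
                                decoder=decoder,
                                num_samples=image_count,
                                items_to_descriptions=items_to_descriptions,
                                num_classes=class_count,
                                labels_to_names=labels_to_names,
                                shuffle=True)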
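get_dataset() relies on two pieces of plumbing that can be checked without TensorFlow: the glob in file_pattern, which TF-slim expands into the matching shard files when the dataset is read, and the {label_id: class_name} dictionary returned by read_label_file(). The sketch below uses only the standard library. fnmatch stands in for TensorFlow's file globbing, and parse_label_file() is a hypothetical helper that assumes the one "id:class_name" entry per line layout TF-slim's dataset_utils uses for labels.txt; this repository's label file may differ.

```python
import fnmatch
import os


def shard_pattern(dataset_dir, dataset_name, split_name):
    # Same format string that get_dataset() passes as data_sources.
    return os.path.join(dataset_dir, '{}_{}_*.tfrecord'.format(dataset_name, split_name))


def matching_shards(pattern, filenames):
    # Keep only the filenames the glob pattern matches (fnmatch is a stand-in
    # for the TensorFlow glob that expands data_sources at read time).
    return sorted(f for f in filenames if fnmatch.fnmatch(f, pattern))


def parse_label_file(text):
    # Hypothetical stand-in for read_label_file(): assumes one "id:class_name"
    # entry per line and skips blank lines.
    labels_to_names = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        index = line.index(':')
        labels_to_names[int(line[:index])] = line[index + 1:]
    return labels_to_names


if __name__ == '__main__':
    # Hypothetical directory, dataset name, shard names, and class names.
    pattern = shard_pattern('data', 'demo', 'train')
    files = [os.path.join('data', 'demo_train_00000-of-00002.tfrecord'),
             os.path.join('data', 'demo_validation_00000-of-00002.tfrecord'),
             os.path.join('data', 'demo_train_00001-of-00002.tfrecord')]
    print(matching_shards(pattern, files))
    print(parse_label_file('0:cats\n1:dogs\n'))
```

Only the two train shards match the pattern, so a "train" Dataset never reads validation records even when both splits share a directory.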