imagenet_utils.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:DeepGold 作者: scottvallance 项目源码 文件源码
def decode_predictions(preds):
    global CLASS_INDEX
    assert len(preds.shape) == 2 and preds.shape[1] == 1000
    if CLASS_INDEX is None:
        fpath = get_file('imagenet_class_index.json',
                         CLASS_INDEX_PATH,
                         cache_subdir='models')
        CLASS_INDEX = json.load(open(fpath))
    indices = np.argmax(preds, axis=-1)
    results = []
    for i in indices:
        results.append(CLASS_INDEX[str(i)])
    return results
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号