retrieve.py source code

python

Project: fast-image-retrieval  Author: xueeinstein
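import os

import numpy as np
from sklearn.neighbors import KDTree

# layer_features() and binary_hash_codes() are project helpers that are
# not shown in this excerpt.


def retrieve_image(target_image, model_file, deploy_file, imagemean_file,
                   threshold=1):
    """Coarse-to-fine retrieval: filter the database by Hamming distance on
    binary latent codes, then rank the survivors by Euclidean distance on
    fc7 features. Returns up to 6 image files and their distances."""
    model_dir = os.path.dirname(model_file)
    # Precomputed database features stored next to the model file
    image_files = np.load(os.path.join(model_dir, 'image_files.npy'))
    fc7_feature_mat = np.load(os.path.join(model_dir, 'fc7_features.npy'))
    latent_feature_file = os.path.join(model_dir, 'latent_features.npy')
    latent_feature_mat = np.load(latent_feature_file)

    candidates = []
    dist = 0
    for layer, mat in layer_features(['latent', 'fc7'], model_file,
                                     deploy_file, imagemean_file,
                                     [target_image], show_pred=True):
        if layer == 'latent':
            # Coarse-level search: binarize the query's latent activations,
            # tile the code to one row per database image, and count the
            # differing bits (Hamming distance) against every stored code.
            mat = binary_hash_codes(mat)
            mat = mat * np.ones((latent_feature_mat.shape[0], 1))
            dis_mat = np.abs(mat - latent_feature_mat)
            hamming_dis = np.sum(dis_mat, axis=1)
            distance_file = os.path.join(model_dir, 'hamming_dis.npy')
            np.save(distance_file, hamming_dis)
            # With 0/1 codes the distances are integers, so the default
            # threshold=1 keeps only exact code matches.
            candidates = np.where(hamming_dis < threshold)[0]

        elif layer == 'fc7':
            # Fine-level search: Euclidean k-NN restricted to the candidates
            if candidates.shape[0] == 0:
                # Nothing passed the Hamming filter; a KDTree cannot be
                # built on an empty set, so return empty results instead.
                return image_files[:0], np.array([])
            kdt = KDTree(fc7_feature_mat[candidates], metric='euclidean')
            k = min(6, candidates.shape[0])  # at most 6 results
            dist, idxs = kdt.query(mat, k=k)
            candidates = candidates[idxs]  # map back to database indices
            print(dist)

    return image_files[candidates][0], dist[0]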
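`retrieve_image` first keeps only database images whose binary latent code lies within `threshold` Hamming distance of the query, then ranks those survivors by Euclidean distance on fc7 features. The dependency-free sketch below shows the same coarse-to-fine idea on toy data. The names `hamming` and `coarse_to_fine` are hypothetical, and a linear scan with `heapq.nsmallest` stands in for the scikit-learn `KDTree` used above; it is an illustration, not the project's implementation.

```python
import heapq
import math


def hamming(a, b):
    """Number of positions where two equal-length binary codes differ."""
    return sum(x != y for x, y in zip(a, b))


def coarse_to_fine(query_code, query_feat, db_codes, db_feats,
                   threshold=1, k=6):
    """Return up to k (index, distance) pairs, nearest first.

    Stage 1 (coarse): keep images whose code is strictly closer than
    `threshold` in Hamming distance, mirroring `hamming_dis < threshold`.
    Stage 2 (fine): rank only those candidates by Euclidean distance.
    """
    candidates = [i for i, code in enumerate(db_codes)
                  if hamming(query_code, code) < threshold]
    scored = ((i, math.dist(query_feat, db_feats[i])) for i in candidates)
    # nsmallest returns fewer than k items when there are fewer candidates,
    # which plays the role of k = min(6, len(candidates)) above.
    return heapq.nsmallest(k, scored, key=lambda pair: pair[1])


# Toy database: 3-bit codes plus 2-D "fc7" features (made-up values)
db_codes = [(0, 1, 1), (0, 1, 1), (1, 0, 0), (0, 1, 0)]
db_feats = [(0.0, 0.0), (3.0, 4.0), (0.1, 0.1), (1.0, 0.0)]

result = coarse_to_fine((0, 1, 1), (0.0, 0.0), db_codes, db_feats,
                        threshold=2, k=2)
print(result)  # [(0, 0.0), (3, 1.0)]
```

Image 2 has the closest features but is discarded in the coarse stage because its code differs in all three bits; this is the speed/recall trade-off that the Hamming threshold controls.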