lambdarandomforest.py source code

Project: rankpy  Author: dmitru
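```python
from multiprocessing import cpu_count

import numpy as np
from joblib import Parallel, delayed

def predict(self, queries, n_jobs=1):
    '''
    Predict the ranking score for each individual document of the given queries.

    queries: the query collection whose documents are scored
        (must provide document_count() and feature_vectors).

    n_jobs: int, optional (default is 1)
        The number of worker threads that will be spawned to compute
        the ranking scores. If -1, the current number of CPUs will be used.

    Returns a float64 array with one score per document, averaged over
    all trees (self.estimators) of the forest.
    '''
    if not self.trained:
        raise ValueError('the model has not been trained yet')

    n_documents = queries.document_count()
    predictions = np.zeros(n_documents, dtype=np.float64)

    # Negative n_jobs counts back from the CPU count (-1 -> all CPUs);
    # never use fewer than 1 thread or more threads than documents.
    n_jobs = max(1, min(n_jobs if n_jobs >= 0 else n_jobs + cpu_count() + 1, n_documents))

    # Boundaries of n_jobs contiguous, roughly equal chunks of documents.
    indices = np.linspace(0, n_documents, n_jobs + 1).astype(np.intc)

    # With the threading backend, each worker writes straight into its own
    # slice of `predictions` (NumPy slices are views, so nothing is copied).
    # __predict is expected to add every tree's scores into that slice, and
    # parallel_helper (defined elsewhere in rankpy) to call getattr(obj, name)(*args).
    # Note: joblib.delayed no longer accepts check_pickle, so it is not passed.
    Parallel(n_jobs=n_jobs, backend="threading")(
        delayed(parallel_helper)(
            LambdaRandomForest, '_LambdaRandomForest__predict', self.estimators,
            queries.feature_vectors[indices[i]:indices[i + 1]],
            predictions[indices[i]:indices[i + 1]])
        for i in range(indices.size - 1))

    # Turn the accumulated sum over trees into the forest average.
    predictions /= len(self.estimators)

    return predictions
```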
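The method splits the documents into contiguous chunks, lets each thread add the tree scores for its chunk into its own slice of a shared output array, and finally divides by the number of trees. Below is a minimal, self-contained sketch of that pattern under stated assumptions. It swaps joblib's threading backend for the standard library's `concurrent.futures.ThreadPoolExecutor`. The "trees" are hypothetical callables that map a feature matrix to one score per row, standing in for the forest's estimators and its private `__predict` helper. The names `resolve_n_jobs`, `chunk_bounds`, `accumulate_chunk` and `forest_predict` are illustrative only and are not part of rankpy.

```python
from concurrent.futures import ThreadPoolExecutor
from os import cpu_count

import numpy as np


def resolve_n_jobs(n_jobs, n_items):
    """Mirror the method's rule: negative values count back from the CPU
    count (-1 -> all CPUs), then clamp the result to [1, n_items]."""
    if n_jobs < 0:
        n_jobs = n_jobs + (cpu_count() or 1) + 1
    return max(1, min(n_jobs, n_items))


def chunk_bounds(n_items, n_chunks):
    """Boundaries of n_chunks contiguous, nearly equal chunks of range(n_items)."""
    return np.linspace(0, n_items, n_chunks + 1).astype(np.intc)


def accumulate_chunk(trees, X, out):
    """Add every (hypothetical) tree's scores for the rows of X into `out` in place."""
    for tree in trees:
        out += tree(X)


def forest_predict(trees, X, n_jobs=1):
    """Average the scores of `trees` over the rows of X using worker threads."""
    n_documents = X.shape[0]
    predictions = np.zeros(n_documents, dtype=np.float64)
    n_jobs = resolve_n_jobs(n_jobs, n_documents)
    bounds = chunk_bounds(n_documents, n_jobs)
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        # predictions[lo:hi] is a view, so each thread fills its own
        # disjoint part of the shared output array.
        futures = [
            pool.submit(accumulate_chunk, trees, X[lo:hi], predictions[lo:hi])
            for lo, hi in zip(bounds[:-1], bounds[1:])
        ]
        for future in futures:
            future.result()  # re-raise any exception raised in a worker
    predictions /= len(trees)
    return predictions


# Usage with two hypothetical "trees" over a 5-document feature matrix.
X = np.array([[1.0, 0.0], [2.0, 1.0], [3.0, 2.0], [4.0, 3.0], [5.0, 4.0]])
trees = [lambda F: F[:, 0], lambda F: F[:, 0] + F[:, 1]]
scores = forest_predict(trees, X, n_jobs=2)  # averages to [1.0, 2.5, 4.0, 5.5, 7.0]
```

Because the threads write to disjoint views, they need no copies or locks. Threads were chosen over processes for exactly this reason: processes could not write into the parent's array. Whether the threads actually run concurrently depends on the per-tree prediction code releasing the GIL. For tiny pure-Python callables like the toy trees here, there is little to gain.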