utils.py source code

python

Project: text-classification-theano  Author: syuoni
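import numpy as np
import theano
from scipy.stats import mode


def predict(self, estimator_args, with_prob=False):
    """Combine the predictions of self.estimators by hard or soft voting."""
    if self.voting == 'hard':
        # Hard voting: each estimator predicts a class label.
        # sub_res -> (estimator_dim, batch_dim)
        sub_res = np.array([estimator.predict_func(*estimator_args) for estimator in self.estimators],
                           dtype=theano.config.floatX)
        # Most frequent label per sample. np.ravel accepts both the older (1, batch_dim)
        # and the newer keepdims=False (batch_dim,) output shapes of scipy.stats.mode.
        mode_res, count = mode(sub_res, axis=0)
        mode_res, count = np.ravel(mode_res), np.ravel(count)
        # The vote share of the winning label serves as its "probability".
        return (mode_res, count / float(self.n_estimators)) if with_prob else mode_res
    else:
        # Soft voting: average the class probabilities of all estimators.
        # sub_res -> (estimator_dim, batch_dim, target_dim)
        sub_res = np.array([estimator.predict_prob_func(*estimator_args) for estimator in self.estimators],
                           dtype=theano.config.floatX)
        sub_res = sub_res.mean(axis=0)                              # (batch_dim, target_dim)
        max_res = np.argmax(sub_res, axis=1)                        # predicted class per sample
        mean_prob = sub_res[np.arange(sub_res.shape[0]), max_res]   # its averaged probability
        return (max_res, mean_prob) if with_prob else max_res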
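The two voting rules in `predict` can be tried out in isolation. Below is a minimal NumPy-only sketch. It assumes the per-estimator outputs have already been stacked into arrays, so the Theano-compiled `predict_func` / `predict_prob_func` are left out, and it assumes labels are non-negative integers. `hard_vote` and `soft_vote` are hypothetical helper names and are not part of the text-classification-theano project. For hard voting, a per-class count replaces `scipy.stats.mode`, but it keeps the same tie-breaking toward the smallest label.

```python
import numpy as np


def hard_vote(labels):
    """Majority vote over estimators (hypothetical helper).

    labels: int array of shape (n_estimators, batch_size), labels assumed >= 0.
    Returns (winning label per sample, fraction of estimators that chose it).
    """
    labels = np.asarray(labels, dtype=int)
    n_estimators = labels.shape[0]
    n_classes = labels.max() + 1
    # counts[c, j] = number of estimators that predicted class c for sample j
    counts = np.stack([(labels == c).sum(axis=0) for c in range(n_classes)])
    # argmax returns the first maximum, so ties go to the smallest label,
    # matching the tie-breaking of scipy.stats.mode
    winner = counts.argmax(axis=0)
    share = counts.max(axis=0) / n_estimators
    return winner, share


def soft_vote(probs):
    """Average class probabilities over estimators (hypothetical helper).

    probs: float array of shape (n_estimators, batch_size, n_classes).
    Returns (argmax class per sample, its averaged probability).
    """
    mean = np.asarray(probs, dtype=float).mean(axis=0)  # (batch_size, n_classes)
    winner = mean.argmax(axis=1)
    return winner, mean[np.arange(mean.shape[0]), winner]


# Three estimators, three samples: one row per estimator
print(hard_vote([[0, 1, 2], [0, 2, 2], [1, 2, 2]]))
# Two estimators, two samples, two classes
print(soft_vote([[[0.9, 0.1], [0.4, 0.6]], [[0.6, 0.4], [0.8, 0.2]]]))
```

In the soft-voting call, the two estimators disagree on the second sample (class 1 vs. class 0). Averaging favors class 0 because the second estimator is more confident. A hard vote on the same two labels would be a tie.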