models.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:cervantes 作者: textclf 项目源码 文件源码
def predict(self, X):
        if type(self.model.input_shape) is tuple:
            X = np.array(X)
            if len(self.model.input_shape) == 2:
                X = X.reshape((X.shape[0], -1))
        else:
            raise LanguageClassifierException('Mult-input models are not supported yet')

        predictions = self.model.predict(X, verbose=True, batch_size=32)
        if (len(predictions.shape) > 1) and (1 not in predictions.shape):
            predictions = predictions.argmax(axis=-1)
        else:
            predictions = 1 * (predictions > 0.5).ravel()
        return predictions
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号