def predict(self, X, batch_size=32, verbose=True):
    """Predict class labels for ``X`` with the wrapped model.

    Only single-input models are supported: ``self.model.input_shape`` must
    be a tuple. When that shape has exactly two entries (presumably
    ``(batch, features)`` -- confirm against the model), every sample in
    ``X`` is flattened to a single row before prediction.

    Args:
        X: Samples to classify; anything ``np.asarray`` accepts. The first
            axis is treated as the sample axis.
        batch_size: Forwarded to ``self.model.predict`` (default 32, the
            previously hard-coded value).
        verbose: Forwarded to ``self.model.predict`` (default True, the
            previously hard-coded value).

    Returns:
        A NumPy integer array of labels. If the model output has more than one
        column (last axis > 1), each label is the argmax over that axis;
        otherwise the scores are thresholded at 0.5 and flattened to 0/1.

    Raises:
        LanguageClassifierException: If ``input_shape`` is not a tuple, e.g.
            for a multi-input model.
    """
    input_shape = self.model.input_shape
    if not isinstance(input_shape, tuple):
        raise LanguageClassifierException('Multi-input models are not supported yet')

    X = np.asarray(X)
    if len(input_shape) == 2:
        # Flatten each sample so it matches a 2-D model input.
        X = X.reshape((X.shape[0], -1))

    predictions = self.model.predict(X, verbose=verbose, batch_size=batch_size)

    # Decide multi-class vs. binary from the output width (last axis) only.
    # Testing for "any axis of size 1" would misroute a single-sample batch
    # of multi-class scores, shape (1, n_classes), into the binary branch.
    if predictions.ndim > 1 and predictions.shape[-1] > 1:
        return predictions.argmax(axis=-1)
    return (predictions > 0.5).astype(int).ravel()
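# 评论列表 ("Comment list") -- scraped web-page navigation residue, not part of the code
# 文章目录 ("Table of contents") -- scraped web-page navigation residue, not part of the code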