def predict(
        self, x=None, input_fn=None, axis=None, batch_size=None, outputs=None,
        as_iterable=True):
    """Returns predictions for given features.

    Args:
      x: features.
      input_fn: Input function. If set, x must be None. (This method does not
        check that itself; it passes both straight to the wrapped estimator.)
      axis: Accepted for interface compatibility but currently IGNORED. It is
        neither forwarded to the wrapped estimator nor used to take an argmax
        here.
      batch_size: Override default batch size.
      outputs: list of `str`, name of the output to predict.
        If `None`, returns all. Only one output is returned to the caller
        (see Returns). If `outputs` is given, it should include that output's
        key; otherwise the lookup below will most likely fail with a
        `KeyError`.
      as_iterable: If True, return an iterable which keeps yielding predictions
        for each example until inputs are exhausted. Note: The inputs must
        terminate if you want the iterable to terminate (e.g. be sure to pass
        num_epochs=1 if you are using something like read_batch_features).

    Returns:
      Numpy array of predicted classes or regression values (or an iterable of
      predictions if as_iterable is True). When `self.params.regression` is
      true, this is the `eval_metrics.INFERENCE_PROB_NAME` entry of the
      estimator's results. Otherwise it is the
      `eval_metrics.INFERENCE_PRED_NAME` entry.
    """
    results = self._estimator.predict(
        x=x, input_fn=input_fn, batch_size=batch_size, outputs=outputs,
        as_iterable=as_iterable)

    # Regression models read the "probabilities" key. Presumably that is where
    # the model stores its regression values -- confirm against the model_fn.
    predict_name = (eval_metrics.INFERENCE_PROB_NAME if self.params.regression
                    else eval_metrics.INFERENCE_PRED_NAME)

    if as_iterable:
        # Lazy: each per-example result is indexed only when the caller
        # iterates, so a missing key surfaces at iteration time. The loop
        # variable is deliberately not called `x`, to avoid shadowing the
        # `x` parameter.
        return (example[predict_name] for example in results)
    return results[predict_name]
# (Web-page scraping residue, not code: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)