def predict(self, fetches=None, feed_dict=None):    # pylint: disable=arguments-differ
    """ Run the model in inference mode and return the requested values.

    Parameters
    ----------
    fetches : tuple, list
        `tf.Operation` and/or `tf.Tensor` objects to evaluate. They are resolved
        through `_fill_fetches` with ``default='predictions'``; presumably that
        default is used when nothing is requested (see `_fill_fetches`).
    feed_dict : dict
        input data mapping placeholder names to numpy values

    Returns
    -------
    Values computed for `fetches`, arranged in the same structure as `fetches`.

    Notes
    -----
    `predict` differs from `train` only in that `train` also runs a `train_step`
    operation, which computes and applies gradients and therefore changes the
    model weights. Here the feed dict is always built with ``is_training=False``.

    See also
    --------
    `Tensorflow Session run <https://www.tensorflow.org/api_docs/python/tf/Session#run>`_
    """
    with self.graph.as_default():
        # Inputs are always prepared in inference mode.
        feeds = self._fill_feed_dict(feed_dict, is_training=False)
        # Resolve what to evaluate; 'predictions' is the fallback spec.
        requested = self._fill_fetches(fetches, default='predictions')
        raw_values = self.session.run(requested, feeds)
        # Post-processing stays inside the graph context, as in the session call.
        return self._fill_output(raw_values, requested)
评论列表
文章目录