def model_argmax(sess, x, predictions, samples, feed=None):
    """
    Helper function that computes the current class prediction.

    Runs ``predictions`` in ``sess`` with ``samples`` fed to ``x`` and takes
    the argmax of the result.

    :param sess: TF session (only its ``run(fetches, feed_dict)`` method
                 is used)
    :param x: the input placeholder, used as the key for ``samples`` in the
              feed dictionary
    :param predictions: the model's symbolic output
                        (presumably evaluates to an array of shape
                        (n_samples, n_classes) -- confirm against the model)
    :param samples: numpy array with input samples (dims must match x)
    :param feed: An optional dictionary that is appended to the feeding
                 dictionary before the session runs. Can be used to feed
                 the learning phase of a Keras model for instance.
                 Its entries override ``x`` if the same key is present.
    :return: the argmax output of predictions, i.e. the current predicted
             class. A single index when ``samples`` holds exactly one
             sample, otherwise an array with one index per sample.
    """
    feed_dict = {x: samples}
    if feed is not None:
        feed_dict.update(feed)

    probabilities = sess.run(predictions, feed_dict)

    # With a single sample the argmax is taken over the flattened output so
    # that callers receive a plain index rather than a length-1 array.
    if samples.shape[0] == 1:
        return np.argmax(probabilities)

    # Otherwise reduce along axis 1 to get one predicted class per sample.
    return np.argmax(probabilities, axis=1)
评论列表
文章目录