def predict(self, model_path, x_test):
    """
    Run the network on ``x_test`` with weights restored from a checkpoint.

    A fresh session is opened for the call. Inside it the variable
    initializer is run, ``self.restore`` is invoked with the session and
    ``model_path``, and ``self.predicter`` is evaluated with ``keep_prob``
    fed as 1.0.

    :param model_path: path to the model checkpoint to restore
    :param x_test: data to predict on; its first three shape entries are
        read here. Documented by the original author as shape
        [n, nx, ny, channels]
    :returns prediction: the value produced by evaluating ``self.predicter``.
        Documented by the original author as shape [n, px, py, labels]
        (px=nx-self.offset/2)
    """
    # The initializer op is built before the session is opened.
    init_op = tf.global_variables_initializer()

    with tf.Session() as session:
        # Give every variable a value first ...
        session.run(init_op)

        # ... then load the previously saved weights over them.
        self.restore(session, model_path)

        # Uninitialized stand-in for the labels, shaped (n, nx, ny, n_class).
        # NOTE(review): presumably fed only because the graph requires a
        # value for self.y -- confirm that self.predicter ignores it.
        dummy_shape = (
            x_test.shape[0],
            x_test.shape[1],
            x_test.shape[2],
            self.n_class,
        )
        y_dummy = np.empty(dummy_shape)

        feeds = {
            self.x: x_test,
            self.y: y_dummy,
            # NOTE(review): 1.0 presumably disables dropout -- confirm.
            self.keep_prob: 1.0,
        }
        prediction = session.run(self.predicter, feed_dict=feeds)

    return prediction
评论列表
文章目录