def evaluate_one_image(
    train_dir='/Users/yangyibo/GitWork/pythonLean/AI/????/testImg/',
    logs_train_dir='/Users/yangyibo/GitWork/pythonLean/AI/????/saveNet/',
):
    """Classify one image from ``train_dir`` using the latest saved checkpoint.

    Builds a fresh graph that casts the image to float32, standardizes it,
    reshapes it into a single-item batch of shape [1, 208, 208, 3], runs
    ``model.inference`` followed by softmax, restores the newest checkpoint
    found in ``logs_train_dir`` and prints the probability of the predicted
    class.

    Args:
        train_dir: Directory passed to ``input_data.get_files``; one image is
            then chosen from the returned file list by ``get_one_image``.
        logs_train_dir: Directory searched with
            ``tf.train.get_checkpoint_state`` for the saved model.

    Returns:
        None. Results are only printed. If no checkpoint is found, a message
        is printed and the function returns without running inference.
    """
    # Collect the image file list; labels are not needed to score one image.
    train, _ = input_data.get_files(train_dir)
    image_array = get_one_image(train)

    with tf.Graph().as_default():
        BATCH_SIZE = 1  # only one image is evaluated, so the batch has one item
        N_CLASSES = 2  # the model emits one score per class (two classes)

        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        # The network consumes a 4-D batch tensor; add the leading batch axis.
        image = tf.reshape(image, [1, 208, 208, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES)
        # Convert the inference scores into class probabilities.
        logit = tf.nn.softmax(logit)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            print("???????????????")
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoint files are named "<prefix>-<step>"; recover the step.
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('??????, ?????? %s' % global_step)
            else:
                print('???????????????')
                # Without restored weights the variables are uninitialized and
                # sess.run would raise FailedPreconditionError; stop here.
                return

            # The image is embedded in the graph as a constant, so no
            # feed_dict is required to evaluate the softmax output.
            prediction = sess.run(logit)
            max_index = np.argmax(prediction)
            # Index a scalar element; formatting a size-1 array with %f relies
            # on deprecated implicit array-to-scalar conversion.
            if max_index == 0:
                print('???? %.6f' % prediction[0, 0])
            else:
                print('???? %.6f' % prediction[0, 1])
# Scraped page residue: "评论列表" (comment list)
# Scraped page residue: "文章目录" (table of contents)