evaluateCatOrDog.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:pythonLean 作者: 527515025 项目源码 文件源码
def evaluate_one_image(train_dir='/Users/yangyibo/GitWork/pythonLean/AI/????/testImg/',
                       logs_train_dir='/Users/yangyibo/GitWork/pythonLean/AI/????/saveNet/'):
    """Classify a single test image with the trained cat/dog network.

    Collects image files from ``train_dir`` via ``input_data.get_files``,
    picks one with ``get_one_image``, and rebuilds the inference graph
    (``model.inference``) with a softmax on top. It then restores the
    latest checkpoint found in ``logs_train_dir`` and prints the
    probability of the winning class.

    Args:
        train_dir: Directory passed to ``input_data.get_files``.
            Defaults to the original hard-coded test-image directory.
        logs_train_dir: Directory searched with
            ``tf.train.get_checkpoint_state`` for saved weights.
            Defaults to the original hard-coded checkpoint directory.

    Returns:
        None. Results are printed. If no checkpoint is found, a message is
        printed and the function returns without running inference (the
        variables would otherwise be uninitialized).
    """
    # Only the file list is needed; the labels are ignored here.
    train, _ = input_data.get_files(train_dir)
    image_array = get_one_image(train)

    with tf.Graph().as_default():
        BATCH_SIZE = 1  # a single image is evaluated, so the batch size is 1
        N_CLASSES = 2  # two output classes (the two print branches below)

        # Feed the image through a placeholder rather than baking it into the
        # graph as a constant. The original fed this placeholder but never
        # used it. Assumes image_array is 208x208x3, matching the reshape
        # below -- TODO confirm against get_one_image.
        x = tf.placeholder(tf.float32, shape=[208, 208, 3])

        image = tf.cast(x, tf.float32)
        # Per-image standardization (zero mean, unit variance).
        image = tf.image.per_image_standardization(image)
        # The network expects a 4-D batch tensor: [batch, height, width, channels].
        image = tf.reshape(image, [1, 208, 208, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES)
        # model.inference output is passed through softmax here to obtain
        # class probabilities.
        logit = tf.nn.softmax(logit)

        saver = tf.train.Saver()

        with tf.Session() as sess:

            print("???????????????")
            # Restore trained weights into this session.
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                # Checkpoint files are named "<prefix>-<global_step>".
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('??????, ?????? %s' % global_step)
            else:
                print('???????????????')
                # Without restored weights the variables are uninitialized and
                # sess.run would raise FailedPreconditionError, so stop here.
                return

            prediction = sess.run(logit, feed_dict={x: image_array})
            # prediction has shape [1, N_CLASSES]; argmax picks the winning class.
            max_index = np.argmax(prediction)
            # Index a scalar element: formatting a 1-element ndarray with %f
            # relies on deprecated array-to-scalar conversion.
            if max_index == 0:
                print('???? %.6f' % prediction[0, 0])
            else:
                print('???? %.6f' % prediction[0, 1])
# ??
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号