tests.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:satellite-image-object-detection 作者: marcbelmont 项目源码 文件源码
def test_inference(self):
        """Check `inference` on a hand-crafted network output.

        The model is built only to learn the shape of its output tensor; the
        values actually fed to `inference` are a zero tensor with two grid
        cells filled in by hand. The test then checks the mask, predicted
        class, confidence and box returned for cell (1, 1).

        NOTE(review): the channel layout is not visible here. The expected
        values suggest channel 0 is an objectness score, channels 1-4 a box,
        and class scores start at channel 5 (channel 10 -> class 5) -- confirm
        against `inference`. The meaning of the `.1` arguments is likewise
        defined by `create_model` / `inference`, not by this test.
        """
        with self.test_session() as sess:
            # Create model; only its output shape is used below.
            net = create_model(tf.zeros([1, IMG_SIZE, IMG_SIZE, 3]), .1)
            net_ph = tf.placeholder(tf.float32, shape=net.shape)
            infer = inference(net_ph, .1)

            # Hand-crafted network output: all zeros except two grid cells.
            output = np.zeros(net.shape).astype(np.float32)
            output[0, 1, 1, :5] = [.84, .4, .68, .346, .346]
            output[0, 1, 1, 10] = .3  # class
            # Same first five channels, but a ten times smaller class score.
            output[0, 2, 2, :5] = [.84, .4, .68, .346, .346]
            output[0, 2, 2, 11] = .03  # class
            result = sess.run([infer], feed_dict={net_ph: output})
            p_box, p_classes, confidence, mask = result[0]

            # Test
            self.assertEqual(mask[0, 1, 1], 1)
            self.assertEqual(p_classes, 5)
            # The graph computes in float32 while `.3 * .84` is a Python
            # double, so compare with a tolerance instead of exact equality.
            np.testing.assert_allclose(confidence, .3 * .84, rtol=1e-5)
            self.assertListEqual(
                [round(x) for x in p_box.tolist()[0]],
                [50, 60, 30, 30],)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号