prediction_v2.py source code

python

Project: tefla  Author: openAGI
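import time

import numpy as np
import tensorflow as tf

# `data` (providing load_image) and `dense_crf` are helpers from the surrounding
# tefla code; their imports are not shown in the original snippet.

def _real_predict(self, X, xform=None, crop_bbox=None):
    # Method of a predictor class; `xform` and `crop_bbox` are unused here.
    tic = time.time()
    # Load the (channels-first) image once and derive both inputs from it.
    img = data.load_image(X, preprocessor=self.preprocessor)
    # Original image as HWC uint8 for the CRF; np.array with a dtype always
    # copies, so the standardizer below cannot modify it.
    img_orig = np.array(img.transpose(1, 2, 0), dtype=np.uint8)
    # Standardized network input: CHW -> HWC, plus a batch axis -> (1, H, W, C).
    X = self.standardizer(img, False)
    X = np.expand_dims(X.transpose(1, 2, 0), 0)
    # Per-pixel class probabilities, refined by a dense CRF run as a Python op.
    # These ops are added to the graph on every call; building them once
    # (e.g. in __init__) would keep the graph from growing.
    probs = tf.nn.softmax(self.predictions)
    refined = tf.py_func(
        dense_crf, [probs, tf.expand_dims(img_orig, axis=0), self.num_classes], tf.float32)
    # Most likely class per pixel; `axis` replaces the deprecated `dimension`.
    labels = tf.argmax(refined, axis=3)
    predictions = self.sess.run(labels, {self.inputs: X})
    predictions = predictions.transpose(0, 2, 1)
    print('took %6.1f seconds' % (time.time() - tic))
    return predictions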
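The input preparation in `_real_predict` turns one channels-first image into two arrays: a uint8 HWC copy for the CRF and a standardized (1, H, W, C) batch for the network. The NumPy sketch below reproduces only those shape changes. It assumes a channels-first image and uses a hypothetical `standardize` function (per-image zero mean, unit variance) in place of tefla's `self.standardizer`, whose actual behavior is not shown in the snippet.

```python
import numpy as np


def standardize(img):
    # Hypothetical stand-in for self.standardizer: zero mean, unit variance.
    return (img - img.mean()) / (img.std() + 1e-8)


def prepare_inputs(img_chw):
    """Return (uint8 HWC image for the CRF, standardized (1, H, W, C) batch)."""
    # np.array with a dtype always copies, so later steps cannot alter it.
    img_orig = np.array(img_chw.transpose(1, 2, 0), dtype=np.uint8)
    # CHW -> HWC, then add a leading batch axis.
    x = standardize(img_chw.astype(np.float32))
    x = np.expand_dims(x.transpose(1, 2, 0), 0)
    return img_orig, x


rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(3, 4, 5)).astype(np.float32)  # C=3, H=4, W=5
img_orig, x = prepare_inputs(img)
print(img_orig.shape, img_orig.dtype)  # (4, 5, 3) uint8
print(x.shape)  # (1, 4, 5, 3)
```

The batch axis is needed because the network placeholder expects a batch of images, even when predicting on a single one.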
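The output path of `_real_predict` applies a softmax over the class axis, refines the probabilities with a dense CRF, takes the argmax over axis 3, and swaps the height and width axes. The following is a minimal NumPy sketch of that decoding on a (N, H, W, K) logits array. The CRF step is replaced by a hypothetical `refine` callable that defaults to the identity, so this does not reproduce `dense_crf` itself.

```python
import numpy as np


def softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def decode(logits, refine=lambda probs: probs):
    """Mirror the output path of _real_predict on (N, H, W, K) logits.

    `refine` is a hypothetical stand-in for the dense CRF step.
    """
    probs = softmax(logits, axis=-1)            # per-pixel class probabilities
    labels = np.argmax(refine(probs), axis=3)   # (N, H, W) class indices
    return labels.transpose(0, 2, 1)            # (N, W, H), as in the original


rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 4, 6, 3))  # N=1, H=4, W=6, K=3 classes
labels = decode(logits)
print(labels.shape)  # (1, 6, 4)
```

Without a CRF, the softmax does not change the argmax because it preserves the ordering of the logits. It matters in the original because the CRF consumes probabilities rather than raw scores. The final `transpose(0, 2, 1)` returns the label map as (N, W, H); whether that layout is intended depends on how callers index the result.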