metrics.py source code

python

Project: tflearn, author: tflearn
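def build(self, predictions, targets, inputs=None):
    """ Prints how many times each predicted class occurs """
    # Assumes TensorFlow 1.x (tf.Print, tf.name_scope) and `import tensorflow as tf`
    # at module level; this is a method of a metric that wraps `self.inner_metric`.
    self.built = True
    pshape = predictions.get_shape()
    self.inner_metric.build(predictions, targets, inputs)

    is_binary = len(pshape) == 1 or (len(pshape) == 2 and int(pshape[1]) == 1)
    # Pick the default name before opening the scope so the scope actually uses it.
    if is_binary:
        self.name = self.name or "binary_prediction_counts"
    else:
        self.name = self.name or "categorical_prediction_counts"

    with tf.name_scope(self.name):
        if is_binary:
            # Binary case: argmax over a single column carries no class information,
            # so threshold the scores instead (assumes probabilities, 0.5 cut-off).
            labels = tf.cast(tf.greater(tf.reshape(predictions, [-1]), 0.5), tf.int64)
        else:
            # Categorical case: the predicted class is the argmax over axis 1.
            labels = tf.argmax(predictions, axis=1)
        y, idx, count = tf.unique_with_counts(labels)
        # tf.Print passes its first input through unchanged and logs [y, count] when
        # evaluated; both branches must pass the wrapped metric's tensor, not the object.
        self.tensor = tf.Print(self.inner_metric.tensor, [y, count],
                               name=self.inner_metric.name)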
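To make the counting logic easy to check without a TensorFlow session, here is a minimal plain-Python sketch of what the graph computes: a per-row argmax for categorical predictions, a threshold for binary ones, then unique values with their counts in order of first occurrence (the ordering `tf.unique_with_counts` uses). The function name `prediction_counts` and the 0.5 threshold are illustrative assumptions, not part of tflearn.

```python
from collections import Counter
from typing import List, Sequence, Tuple, Union

Row = Union[float, Sequence[float]]


def prediction_counts(predictions: Sequence[Row],
                      threshold: float = 0.5) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (classes, counts) like tf.unique_with_counts."""
    labels: List[int] = []
    for row in predictions:
        # A width-1 row is treated as a scalar score, mirroring the (N, 1) binary case.
        if isinstance(row, (list, tuple)) and len(row) == 1:
            row = row[0]
        if isinstance(row, (list, tuple)):
            # Categorical: index of the largest score; ties resolve to the first index.
            labels.append(max(range(len(row)), key=lambda i: row[i]))
        else:
            # Binary: assumed positive-class probability, thresholded to 0 or 1.
            labels.append(int(row > threshold))
    # Counter keeps insertion order, so classes appear in order of first occurrence.
    counts = Counter(labels)
    return list(counts.keys()), list(counts.values())


if __name__ == "__main__":
    print(prediction_counts([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]]))
    print(prediction_counts([0.9, 0.2, 0.7]))
```

A heavily skewed count (for example, nearly every row mapped to one class) is the signal this metric is meant to surface during training.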