misclassification_rate.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:master-thesis 作者: AndreasMadsen 项目源码 文件源码
def _build_metric(self, model: 'code.model.abstract.Model') -> tf.Tensor:
        with tf.name_scope(None, self.metric_name,
                           values=[self.dataset.source, self.dataset.target]):
            x = self.dataset.source
            y = self.dataset.target
            length = self.dataset.length

            # build mask
            mask = tf.cast(tf.not_equal(y, tf.zeros_like(y)), tf.float32)

            # create masked error tensor
            errors = tf.not_equal(
                model.inference_model(x, length, reuse=True), y
            )
            errors = tf.cast(errors, tf.float32) * mask  # mask errors

            # tf.sum(mask) is the number of unmasked elements
            return tf.reduce_sum(errors) / tf.reduce_sum(mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号