def _build_metric(self, model: 'code.model.abstract.Model') -> tf.Tensor:
    """Build the masked misclassification-rate tensor for ``model``.

    Runs ``model.inference_model`` on ``self.dataset.source`` (together with
    ``self.dataset.length`` and ``reuse=True``) and compares its output
    element-wise with ``self.dataset.target``. Target positions equal to
    zero are masked out: they count neither as errors nor toward the
    denominator.

    Args:
        model: The model whose ``inference_model`` produces predictions that
            are compared against ``self.dataset.target``.

    Returns:
        A scalar float32 tensor holding the fraction of unmasked target
        elements that were mispredicted. If every target element is masked,
        the result is 0.0 instead of NaN.
    """
    with tf.name_scope(None, self.metric_name,
                       values=[self.dataset.source, self.dataset.target]):
        x = self.dataset.source
        y = self.dataset.target
        length = self.dataset.length

        # 1.0 where the target is non-zero, 0.0 elsewhere.
        # NOTE(review): presumably 0 is the padding id in ``y`` -- confirm.
        mask = tf.cast(tf.not_equal(y, tf.zeros_like(y)), tf.float32)

        # Element-wise error indicator, restricted to unmasked positions.
        predictions = model.inference_model(x, length, reuse=True)
        errors = tf.cast(tf.not_equal(predictions, y), tf.float32) * mask

        # tf.reduce_sum(mask) is the number of unmasked elements. Clamp it
        # to 1 so an all-masked input yields 0/1 = 0.0 rather than 0/0 = NaN
        # (the numerator is necessarily 0 in that case).
        num_unmasked = tf.reduce_sum(mask)
        return tf.reduce_sum(errors) / tf.maximum(num_unmasked, 1.0)
misclassification_rate.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录