metrics.py source code

python
Project: tefla  Author: openAGI
def create_metric_ops(self, _inputs, labels, predictions):
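        """Creates the (value, update_op) tensor pair for this metric.

        `value` evaluates the metric over the accumulated strings;
        `update_op` appends the current batch to the accumulators and
        returns the metric value.
        """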
        with tf.variable_scope(self._name):

            # Join tokens into single strings
            predictions_flat = tf.reduce_join(
                predictions["predicted_tokens"], 1, separator=self._separator)
            labels_flat = tf.reduce_join(
                labels["target_tokens"], 1, separator=self._separator)

            sources_value, sources_update = accumulate_strings(
                values=predictions_flat, name="sources")
            targets_value, targets_update = accumulate_strings(
                values=labels_flat, name="targets")

            metric_value = tf.py_func(
                func=self._py_func,
                inp=[sources_value, targets_value],
                Tout=tf.float32,
                name="value")

        with tf.control_dependencies([sources_update, targets_update]):
            update_op = tf.identity(metric_value, name="update_op")

        return metric_value, update_op
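
The method above follows the streaming-metric pattern: each batch's token rows are joined into strings, appended to running "sources" and "targets" collections, and the score is computed over everything collected so far. Since `accumulate_strings` and `self._py_func` are not shown in this snippet, the following is only a minimal pure-Python sketch of that pattern, not the tefla implementation. The class `StringAccumulatingMetric` and the scoring function `exact_match` are hypothetical names introduced for illustration.

```python
class StringAccumulatingMetric:
    """Pure-Python analogue of the (value, update_op) pattern (illustrative only)."""

    def __init__(self, score_fn, separator=" "):
        self._score_fn = score_fn    # hypothetical stand-in for self._py_func
        self._separator = separator
        self._sources = []           # accumulated prediction strings
        self._targets = []           # accumulated label strings

    def update(self, predicted_tokens, target_tokens):
        """Plays the role of update_op: accumulate a batch, then return the value."""
        # Join each row of tokens into one string, like tf.reduce_join along axis 1.
        self._sources.extend(self._separator.join(row) for row in predicted_tokens)
        self._targets.extend(self._separator.join(row) for row in target_tokens)
        return self.value()

    def value(self):
        """Plays the role of metric_value: score everything accumulated so far."""
        return float(self._score_fn(self._sources, self._targets))


def exact_match(sources, targets):
    """Hypothetical scoring function: fraction of predictions equal to their target."""
    if not targets:
        return 0.0
    hits = sum(s == t for s, t in zip(sources, targets))
    return hits / len(targets)
```

Calling `update` once per evaluation batch and reading `value()` at the end mirrors how the `update_op` and `metric_value` tensors would typically be used; in practice a corpus-level score such as BLEU would likely replace `exact_match`.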