base.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:dataset 作者: analysiscenter 项目源码 文件源码
def output(self, inputs, ops=None, prefix=None, **kwargs):
        """ Add output operations to a model graph, like predictions, quality metrics, etc.

        For each input tensor this adds a ``predictions`` op (built from the
        ``predictions`` entry of the filled-in ``output`` params, which may be ``None``)
        plus one op per entry of ``ops``. All of them are created via ``self._add_output_op``.
        When an input has a non-empty prefix, its ops are created inside
        ``tf.variable_scope(prefix)``, and the op names are prefixed with ``prefix + '_'``.

        Parameters
        ----------
        inputs : tf.Tensor or a sequence of tf.Tensors
            input tensors

        ops : str or a sequence of str
            operation names::
            - 'sigmoid' - add ``sigmoid(inputs)``
            - 'proba' - add ``softmax(inputs)``
            - 'labels' - add ``argmax(inputs)``
            - 'accuracy' - add ``mean(predicted_labels == true_labels)``

        prefix : str or a sequence of str
            a prefix for each input if there are multiple inputs.
            For a single (non-sequence) input it defaults to ``'output'``.
            A plain ``str`` is treated as a one-element sequence.

        Raises
        ------
        ValueError if the number of outputs does not equal to the number of prefixes
            (including the case of a sequence of inputs given without any prefix)
        TypeError if inputs is not a Tensor or a sequence of Tensors
        """
        from contextlib import ExitStack

        kwargs = self.fill_params('output', **kwargs)
        predictions_op = self.pop('predictions', kwargs, default=None)

        if ops is None:
            ops = []
        elif not isinstance(ops, (list, tuple)):
            ops = [ops]

        if not isinstance(inputs, (tuple, list)):
            inputs = [inputs]
            # a lone output gets a default scope name
            prefix = [prefix or 'output']
        elif prefix is None:
            # without this check len(None) below would raise an unhelpful TypeError
            raise ValueError('Each output in multiple output models should have its own prefix')
        elif isinstance(prefix, str):
            # otherwise a str would be compared (and indexed) character by character
            prefix = [prefix]

        if len(inputs) != len(prefix):
            raise ValueError('Each output in multiple output models should have its own prefix')

        single_output = len(inputs) == 1
        for tensor, current_prefix in zip(inputs, prefix):
            if not isinstance(tensor, tf.Tensor):
                # report the offending element, not the wrapping list
                raise TypeError("Network output is expected to be a Tensor, but given {}".format(type(tensor)))

            attr_prefix = current_prefix + '_' if current_prefix else ''
            # with a single output, predictions keep the bare name 'predictions'
            pred_prefix = '' if single_output else attr_prefix

            # ExitStack guarantees the variable scope is exited even if an op fails to build
            with ExitStack() as stack:
                if current_prefix:
                    stack.enter_context(tf.variable_scope(current_prefix))
                self._add_output_op(tensor, predictions_op, 'predictions', pred_prefix, **kwargs)
                for oper in ops:
                    self._add_output_op(tensor, oper, oper, attr_prefix, **kwargs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号