tensor_forest.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _get_loss(self, features, labels, data_spec=None):
    """Constructs, caches, and returns the inference-based loss."""
    if self._loss is not None:
      return self._loss

    def _average_loss():
      probs = self.inference_graph(features, data_spec=data_spec)
      return math_ops.reduce_sum(self.loss_fn(
          probs, labels)) / math_ops.to_float(
              array_ops.shape(features)[0])

    self._loss = control_flow_ops.cond(
        self.average_size() > 0, _average_loss,
        lambda: constant_op.constant(sys.maxsize, dtype=dtypes.float32))

    return self._loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号