def _get_loss(self, features, labels, data_spec=None):
    """Build the inference-based loss on first use and return the cached tensor.

    The first call constructs a ``control_flow_ops.cond`` op and stores it on
    ``self._loss``; every later call returns that stored value unchanged,
    ignoring the arguments it is given.

    The conditional op selects between two branches on
    ``self.average_size() > 0``:
      * true  -> sum of ``self.loss_fn(probs, labels)`` divided by the first
                 entry of ``array_ops.shape(features)`` (cast to float), where
                 ``probs`` comes from ``self.inference_graph``;
      * false -> a float32 constant holding ``sys.maxsize`` (presumably a
                 "very large loss" sentinel -- TODO confirm intent).

    Args:
      features: Forwarded to ``self.inference_graph``; its leading dimension
        is used as the divisor (assumed to be the example count -- verify
        against callers).
      labels: Forwarded to ``self.loss_fn`` together with the predictions.
      data_spec: Optional value forwarded to ``self.inference_graph``.

    Returns:
      The tensor held in ``self._loss``.
    """
    if self._loss is None:

        def _mean_example_loss():
            # Built lazily inside the cond branch so the inference graph is
            # only constructed as part of the "true" branch.
            predictions = self.inference_graph(features, data_spec=data_spec)
            summed = math_ops.reduce_sum(self.loss_fn(predictions, labels))
            divisor = math_ops.to_float(array_ops.shape(features)[0])
            return summed / divisor

        def _fallback_loss():
            return constant_op.constant(sys.maxsize, dtype=dtypes.float32)

        size_is_positive = self.average_size() > 0
        self._loss = control_flow_ops.cond(
            size_is_positive, _mean_example_loss, _fallback_loss)

    return self._loss
评论列表
文章目录