def loss(loss_value):
    """Build ops that accumulate a running mean of ``loss_value``.

    Two non-trainable variables are created: a float running sum that starts
    at 0.0 and an integer sample counter that starts at 0.

    Args:
        loss_value: Value added to the running sum each time the update ops
            are run.  Presumably a scalar float32 tensor, since it is added
            to a scalar float variable -- TODO confirm against callers.

    Returns:
        A ``(update_ops, mean_op)`` pair.  ``update_ops`` is the list
        ``[sum_update, count_update]``; running both adds ``loss_value`` to
        the running sum and 1 to the counter.  ``mean_op`` is the running sum
        divided by the counter.  While the counter is still zero the divisor
        is clamped to 1, so an untouched accumulator reads 0.0 instead of
        NaN.
    """
    total_loss = tf.Variable(0.0, trainable=False)
    loss_count = tf.Variable(0, trainable=False)

    total_loss_update = tf.assign_add(total_loss, loss_value)
    loss_count_update = tf.assign_add(loss_count, 1)

    # Clamp the divisor to at least 1: evaluating the mean before any update
    # would otherwise compute 0.0 / 0 and yield NaN.  Once the counter is
    # >= 1 this is exactly sum / count.
    safe_count = tf.cast(tf.maximum(loss_count, 1), tf.float32)
    loss_op = total_loss / safe_count
    return [total_loss_update, loss_count_update], loss_op
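# Comment list (web-page navigation label left over from extraction; not code)
# Article table of contents (web-page navigation label left over from extraction; not code)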