import gflags
import tensorflow as tf


def apply_loss(labels, net_out, loss_fn, weight_decay, is_training,
               return_mean_loss=False, mask_voids=True):
    '''Applies the user-specified loss function and returns the loss

    Note:
        tf.nn.sparse_softmax_cross_entropy_with_logits expects the labels
        as integer class indices (NOT one-hot) and net_out as raw,
        unnormalized logits with one score per class.
    '''
    cfg = gflags.cfg  # global config object attached to gflags elsewhere

    if mask_voids and len(cfg.void_labels):
        # TODO Check this: tf.not_equal broadcasts, so this is only
        # correct for a single void label
        print('Masking the void labels')
        # Mark the non-void pixels, then remap void pixels to class 0 so
        # the loss function receives a valid (if arbitrary) label
        mask = tf.not_equal(labels, cfg.void_labels)
        labels *= tf.cast(mask, 'int32')  # void_class --> 0 (random class)
        # Train loss (labels are expected to be already flattened to [-1])
        loss = loss_fn(labels=labels,
                       logits=tf.reshape(net_out, [-1, cfg.nclasses]))
        # Zero out the loss contribution of the void pixels
        mask = tf.cast(mask, 'float32')
        loss *= mask
    else:
        # Train loss
        loss = loss_fn(labels=labels,
                       logits=tf.reshape(net_out, [-1, cfg.nclasses]))

    if is_training:
        loss = apply_l2_penalty(loss, weight_decay)

    # Return the mean loss (over pixels *and* batches); when voids are
    # masked, average over the non-void pixels only
    if return_mean_loss:
        if mask_voids and len(cfg.void_labels):
            return tf.reduce_sum(loss) / tf.reduce_sum(mask)
        else:
            return tf.reduce_mean(loss)
    else:
        return loss
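
apply_l2_penalty is defined elsewhere in the repository and is not shown here. Below is a minimal sketch of what such a helper typically looks like, assuming weight decay is implemented as weight_decay times the summed L2 norms of the trainable weight variables; the bias filter and variable naming are illustrative assumptions, not the repository's actual code:

def apply_l2_penalty(loss, weight_decay):
    # Hypothetical sketch: sum the L2 norms of the trainable weight
    # variables (biases are conventionally excluded from weight decay)
    # and add the scaled penalty to every element of the loss tensor.
    penalty = tf.add_n([tf.nn.l2_loss(v)
                        for v in tf.trainable_variables()
                        if 'bias' not in v.name.lower()])
    return loss + weight_decay * penalty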
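
For context, a hypothetical call site, assuming a per-pixel segmentation setup where labels arrive already flattened, cfg.nclasses is 12, and the user-specified loss_fn is sparse softmax cross-entropy; the placeholder shapes and the weight_decay value are illustrative only:

import tensorflow as tf

# Flattened integer labels and per-pixel logits (TF 1.x style graph inputs)
labels = tf.placeholder(tf.int32, shape=[None], name='flat_labels')
net_out = tf.placeholder(tf.float32, shape=[None, 12], name='logits')

mean_loss = apply_loss(
    labels, net_out,
    loss_fn=tf.nn.sparse_softmax_cross_entropy_with_logits,
    weight_decay=1e-4,  # illustrative value, not from the source
    is_training=True,
    return_mean_loss=True)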