def one_hot_wrapper(num_classes, loss_fn):
    """Adapt a loss function that expects one-hot labels to take class indices.

    Some loss functions take one-hot labels rather than integer class ids.
    This returns a closure that one-hot encodes ``targets`` and then
    delegates to ``loss_fn``.

    Args:
        num_classes: Depth of the one-hot encoding (number of classes).
        loss_fn: Callable invoked as ``loss_fn(probs, one_hot_labels)``.

    Returns:
        A function ``_loss(probs, targets)`` that casts ``targets`` to int32,
        one-hot encodes them as float32 (1.0 for the target class, 0.0
        elsewhere), and returns ``loss_fn(probs, one_hot_labels)``.
    """
    def _loss(probs, targets):
        # one_hot requires integer indices. Use cast instead of the
        # deprecated to_int32; the resulting values are the same.
        # NOTE(review): assumes targets hold class ids in [0, num_classes);
        # the required shape relative to probs is not visible here. Confirm.
        int_targets = math_ops.cast(targets, dtypes.int32)
        one_hot_labels = array_ops.one_hot(
            int_targets, num_classes,
            on_value=1., off_value=0., dtype=dtypes.float32)
        return loss_fn(probs, one_hot_labels)
    return _loss
评论列表
文章目录