def _GetLossFn(name):
'''
Helper function for selecting loss function
name: The name of the loss function
return: A handle for a loss function LF(YH, Y)
'''
return {'cos': lambda YH, Y : tf.losses.cosine_distance(Y, YH), 'hinge': lambda YH, Y : tf.losses.hinge_loss(Y, YH),
'l1': lambda YH, Y : tf.losses.absolute_difference(Y, YH), 'l2': lambda YH, Y : tf.squared_difference(Y, YH),
'log': lambda YH, Y : tf.losses.log_loss(Y, YH),
'sgce': lambda YH, Y : tf.nn.sigmoid_cross_entropy_with_logits(labels = Y, logits = YH),
'smce': lambda YH, Y : tf.nn.softmax_cross_entropy_with_logits(labels = Y, logits = YH)}.get(name)
评论列表
文章目录