def rank_crossentropy_loss(kwargs=None):
    """Build a listwise softmax cross-entropy ranking loss.

    The returned loss expects each batch to be laid out in consecutive
    groups of ``neg_num + 1`` rows. Row 0 of every group is taken as the
    positive example, and the following ``neg_num`` rows are its negatives.
    For every group the rows are placed side by side (one column per row of
    the group). A softmax cross-entropy is taken across that group, and the
    mean over all groups is returned.

    Args:
        kwargs: Optional dict of options. Only ``'neg_num'`` is read: the
            number of negatives per positive, default 1. ``None``, a
            non-dict, or a dict without ``'neg_num'`` all use the default.

    Returns:
        A ``loss(y_true, y_pred)`` callable that returns a scalar tensor.

    Raises:
        ValueError: if ``neg_num`` is negative.
    """
    neg_num = 1
    if isinstance(kwargs, dict) and 'neg_num' in kwargs:
        neg_num = kwargs['neg_num']
    # A negative neg_num gives a zero or negative slice step below. That
    # either errors obscurely or silently reverses rows, so reject it here.
    if neg_num < 0:
        raise ValueError('neg_num must be >= 0, got {!r}'.format(neg_num))
    group_size = neg_num + 1

    def _cross_entropy_loss(y_true, y_pred):
        # Strided row slices: offset 0 selects every positive row, and
        # offset k (1..neg_num) selects the k-th negative of each group.
        # Plain tensor slicing replaces per-call Keras Lambda layers. Those
        # layers were rebuilt on every loss evaluation, and their lambdas
        # captured the loop index by late binding.
        # NOTE(review): assumes y_true / y_pred are rank-2 with one score
        # column per row, and that the batch size is a multiple of
        # group_size. Otherwise the slices differ in length and tf.concat
        # fails. Presumably y_true is 1 for the positive and 0 for the
        # negatives — confirm against the data pipeline.
        logits = tf.concat(
            [y_pred[k::group_size, :] for k in range(group_size)], axis=1)
        labels = tf.concat(
            [y_true[k::group_size, :] for k in range(group_size)], axis=1)
        return tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))

    return _cross_entropy_loss
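# NOTE: the two labels below are leftovers from the web page this snippet
# was copied from; they are commented out so the module parses.
# 评论列表 (Comment list)
# 文章目录 (Table of contents)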