def max_oracle(scores, y_truth):
    """Build the symbolic loss-augmented argmax ("max oracle") for a batch.

    For every sample, choose the label that maximizes its score plus a 0/1
    task loss. The task loss is 0 for the true label and 1 for every other
    label. Also return the total task loss incurred by those choices.

    Args:
        scores: symbolic score tensor indexed as ``scores[sample, class]``;
            presumably shape (n_samples, n_classes) -- TODO confirm.
        y_truth: symbolic vector of ground-truth labels, one per sample.
            ``to_one_hot`` expects an integer vector.

    Returns:
        A pair ``(y_star, delta)``. ``y_star`` is the per-sample argmax of
        ``scores`` plus the task loss. ``delta`` is the summed 0/1 loss of
        ``y_star`` against ``y_truth``.
    """
    # NOTE(review): ``T`` is assumed to be ``theano.tensor``, imported
    # elsewhere in this file.
    num_samples = y_truth.shape[0]
    num_classes = scores.shape[1]

    # 0/1 task-loss table: row i is 0 at y_truth[i] and 1 everywhere else.
    task_loss = 1. - T.extra_ops.to_one_hot(y_truth, num_classes)

    # Per sample, pick the label with the highest loss-augmented score.
    augmented_scores = scores + task_loss
    y_star = T.argmax(augmented_scores, axis=1)

    # Look up the task loss of each chosen label and total it over the batch.
    row_idx = T.arange(num_samples)
    delta = task_loss[row_idx, y_star].sum()

    return y_star, delta
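# Trailing navigation text from the scraped web page, kept as comments so the file parses:
# 评论列表 (Comment list)
# 文章目录 (Table of contents)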