def compute_loss(self, scores, scores_no_dropout, labels):
    """Build the scalar training-loss tensor.

    The data term depends on configuration:

    * ``self.viterbi`` true: negative CRF log-likelihood
      (``tf.contrib.crf.crf_log_likelihood`` with ``self.transition_params``),
      averaged over the batch.
    * ``self.which_loss`` in ``("mean", "block")``: softmax cross-entropy,
      multiplied by ``self.input_mask`` and divided by the mask sum.
    * ``self.which_loss == "sum"``: the same masked cross-entropy, summed.
    * ``self.which_loss == "margin"``: masked hinge loss
      ``relu(margin + best_wrong_score - true_score)``, averaged over all
      ``batch_size * max_seq_len`` positions.

    Any other ``which_loss`` value (with ``viterbi`` false) adds no data term.

    Two regularizers are always added:
    ``self.l2_penalty * self.l2_loss`` and
    ``self.drop_penalty * tf.nn.l2_loss(scores - scores_no_dropout)``.

    Args:
        scores: Per-token class logits. The margin branch reshapes them to
            ``[batch_size, max_seq_len, num_classes]``, so they must hold
            that many elements.
        scores_no_dropout: Logits of the same shape as ``scores``; only
            their element-wise difference from ``scores`` is used.
        labels: Integer class ids, one per token position.

    Returns:
        The total loss tensor. It is a scalar, presumably; this assumes
        ``self.l2_loss`` is a scalar (TODO confirm).
    """
    loss = tf.constant(0.0)
    if self.viterbi:
        # self.sequence_lengths is reduced over axis 1, so it is at least 2-D.
        zero_elements = tf.equal(self.sequence_lengths, tf.zeros_like(self.sequence_lengths))
        count_zeros_per_row = tf.reduce_sum(tf.cast(zero_elements, tf.int32), axis=1)
        # Per-row CRF length: the sum of the row's lengths, plus 2 for every
        # zero-length entry. NOTE(review): the reason for the +2 per empty
        # entry is not visible here; presumably it accounts for padding or
        # separator tokens in the flattened sequence. Confirm.
        flat_sequence_lengths = tf.add(tf.reduce_sum(self.sequence_lengths, 1),
                                       tf.scalar_mul(2, count_zeros_per_row))
        # The returned transition params are unused: self.transition_params
        # is passed in and is the variable being trained.
        log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
            scores, labels, flat_sequence_lengths,
            transition_params=self.transition_params)
        loss += tf.reduce_mean(-log_likelihood)
    else:
        if self.which_loss in ("mean", "block", "sum"):
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=scores, labels=labels)
            masked_losses = tf.multiply(losses, self.input_mask)
            if self.which_loss == "sum":
                loss += tf.reduce_sum(masked_losses)
            else:
                # Mean over real (mask-weighted) tokens only.
                loss += tf.divide(tf.reduce_sum(masked_losses), tf.reduce_sum(self.input_mask))
        elif self.which_loss == "margin":
            # TODO: move into tf_utils, together with the "flat index into a
            # 3-D tensor" computation below.
            flat_size = self.batch_size * self.max_seq_len * self.num_classes
            flat_scores = tf.reshape(scores, [-1])
            flat_labels = tf.reshape(labels, [-1])

            # Flat index of the true-label score at every (batch, token)
            # position of a [batch_size, max_seq_len, num_classes] tensor.
            batch_offsets = tf.multiply(tf.range(self.batch_size),
                                        self.max_seq_len * self.num_classes)
            repeated_batch_offsets = tf_utils.repeat(batch_offsets, self.max_seq_len)
            tok_offsets = tf.multiply(tf.range(self.max_seq_len), self.num_classes)
            tiled_tok_offsets = tf.tile(tok_offsets, [self.batch_size])
            indices = tf.add(tf.add(repeated_batch_offsets, tiled_tok_offsets), flat_labels)

            # Push the true-label score to -inf so the max below is over the
            # wrong labels only. Use -np.inf, not np.NINF: NumPy 2.0 removed
            # the NINF alias.
            true_label_to_ninf = tf.sparse_to_dense(indices, [flat_size], -np.inf)
            loss_augmented = tf.reshape(tf.add(flat_scores, true_label_to_ninf),
                                        [self.batch_size, self.max_seq_len, self.num_classes])
            max_scores = tf.reshape(tf.reduce_max(loss_augmented, [-1]), [-1])

            # True-label score minus the margin. This equals gathering from
            # scores + sparse(-margin), without building a dense tensor.
            label_scores = tf.subtract(tf.gather(flat_scores, indices), self.margin)

            # margin + max_logit - correct_logit == max_logit - (correct - margin)
            max2_diffs = tf.subtract(max_scores, label_scores)
            mask = tf.reshape(self.input_mask, [-1])
            # NOTE(review): reduce_mean divides by every position, padded ones
            # included, unlike the "mean" branch, which divides by the mask sum.
            loss += tf.reduce_mean(tf.multiply(mask, tf.nn.relu(max2_diffs)))
        # NOTE(review): any other which_loss value silently contributes no
        # data term; only the regularizers below are added.
    loss += self.l2_penalty * self.l2_loss
    # Penalize disagreement between the dropout and no-dropout logits.
    drop_loss = tf.nn.l2_loss(tf.subtract(scores, scores_no_dropout))
    loss += self.drop_penalty * drop_loss
    return loss
评论列表
文章目录