def compute_loss(self):
    """Compute the training loss.

    Without CRF (``self._use_crf`` falsy): token-level cross-entropy,
    masked to the first ``sequence_actual_length`` positions of each
    sequence, divided by that length, then averaged over the batch.

    With CRF: the batch mean of the negative log-likelihood returned by
    ``tf.contrib.crf.crf_log_likelihood``. As a side effect, the transition
    matrix returned by that call is stored on ``self.transition_params``.

    Return:
        loss: scalar
    """
    if not self._use_crf:
        # One-hot labels of shape [batch, sequence_length, nb_classes].
        labels = tf.one_hot(
            tf.reshape(self.input_label_ph, [-1, self._sequence_length]),
            depth=self._nb_classes,
            dtype=tf.float32)

        # NOTE(review): tf.log is applied to self.logits directly, so in this
        # branch it is assumed to already hold probabilities (e.g. a softmax
        # output) -- confirm against where self.logits is built.
        # Flooring avoids log(0) = -inf. Without the floor, labels * log(p)
        # would give 0 * -inf = NaN for non-target classes and poison the loss.
        probs = tf.maximum(self.logits, 1e-10)
        cross_entropy = -tf.reduce_sum(labels * tf.log(probs), axis=2)

        # Mask by the true sequence lengths. A mask derived from the one-hot
        # labels would be 1 for every in-range label id. Padding tagged with a
        # real class id (e.g. 0) would then leak into the sum, while the
        # denominator below counts only the real tokens.
        mask = tf.sequence_mask(
            self.sequence_actual_length,
            maxlen=self._sequence_length,
            dtype=tf.float32)

        # Clamp to 1 so a zero-length sequence contributes 0 instead of NaN.
        lengths = tf.maximum(
            tf.cast(self.sequence_actual_length, tf.float32), 1.0)
        cross_entropy_masked = tf.reduce_sum(cross_entropy * mask, axis=1) / lengths
        return tf.reduce_mean(cross_entropy_masked)
    else:
        # crf_log_likelihood also returns the transition matrix. It is kept on
        # the instance (presumably for decoding elsewhere -- verify with callers).
        log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood(
            self.logits, self.input_label_ph, self.sequence_actual_length)
        return tf.reduce_mean(-log_likelihood)
评论列表
文章目录