def loss_crf(self):
    """
    Build the CRF loss for the current mini-batch.

    Runs ``self.loss_crf_scan`` over the leading (batch) axis of
    ``self.prediction``, the reshaped ``self.x_tokens_len`` and ``self.y``
    via ``tf.scan``. It then sums the scan outputs and divides by the batch
    size read from ``self.x_tokens``.

    :return: scalar float32 tensor with the batch-normalized CRF loss.
    """
    # Give the length tensor an explicit leading batch axis: [batch, -1].
    # NOTE(review): presumably x_tokens_len is 1-D, giving [batch, 1] -- confirm.
    lengths_2d = tf.reshape(
        self.x_tokens_len,
        [tf.shape(self.x_tokens_len)[0], -1],
    )

    # tf.scan slices every element tensor along axis 0 in lockstep and calls
    # loss_crf_scan(accumulator, (pred_i, len_i, y_i)), starting from 0.0.
    # NOTE(review): summing the outputs below assumes loss_crf_scan returns
    # the loss of the current example, not a running total -- confirm.
    scan_outputs = tf.scan(
        self.loss_crf_scan,
        [self.prediction, lengths_2d, self.y],
        initializer=0.0,
        infer_shape=True,
        back_prop=True,
    )

    # Normalize the summed loss by the number of examples in x_tokens.
    summed = tf.reduce_sum(scan_outputs)
    batch_size = tf.cast(tf.shape(self.x_tokens)[0], dtype=tf.float32)
    return tf.divide(summed, batch_size)
评论列表
文章目录