def crf_loss(y, y_, transitions, nums_tags, batch_size):
    """Compute the CRF loss: total path score minus the gold-path score.

    The gold-path score adds up, over the unmasked positions of every
    sequence, the score of the gold tag at each step plus the transition
    score from the previous gold tag. The total path score is the first of
    the three values produced by ``Forward(...)()``.

    Args:
        y: Per-step tag scores. The reductions below index it as
            (batch, step, tag) -- assumed layout, confirm against the caller.
        y_: Gold tag ids, indexed as (batch, step). Positions where
            ``tf.sign`` yields 0 (tag id 0) are masked out; presumably id 0
            marks padding -- confirm.
        transitions: Transition scores, looked up as
            ``transitions[previous_tag, current_tag]``. Row ``nums_tags`` is
            used for the start-of-sequence marker.
        nums_tags: Number of tags; also the id given to the start marker.
        batch_size: Number of sequences, used to build the start-marker
            column with Python list repetition.

    Returns:
        ``-(gold_path_score - total_path_score)``. The gold-path score is a
        scalar; the shape of the result otherwise follows whatever
        ``Forward`` returns.
    """
    emissions = y
    gold_ids = y_

    # Number of steps = number of tensors produced by unstacking along axis 1.
    step_count = len(tf.unstack(emissions, axis=1))

    # tf.sign maps tag id 0 to 0, so those positions drop out of both scores.
    keep = tf.cast(tf.sign(y_), dtype=tf.float32)
    lengths = tf.reduce_sum(tf.sign(y_), axis=1)

    # Prepend the start marker (id == nums_tags) to every sequence so that
    # each step has a "previous tag" to pair with.
    start_row = [nums_tags]
    start_column = tf.stack([start_row] * batch_size)
    shifted_ids = tf.concat(values=[start_column, gold_ids], axis=1)

    # One (previous_tag, current_tag) pair per step, stacked along axis 1.
    tag_pairs = [
        tf.slice(shifted_ids, [0, step], [-1, 2]) for step in range(step_count)
    ]
    pair_indices = tf.stack(tag_pairs, axis=1)

    # Score of the gold tag at each step, zeroed where masked.
    gold_one_hot = tf.contrib.layers.one_hot_encoding(gold_ids, nums_tags)
    emission_score = tf.reduce_sum(emissions * gold_one_hot, axis=2)
    emission_score = emission_score * keep

    # Look up transitions[previous, current] through a flattened index:
    # the exclusive reverse cumprod of the shape gives the row-major strides,
    # so sum(strides * pair) is the offset into the flattened matrix.
    # (For a 2-D matrix this is what tf.gather_nd on the pairs would return.)
    strides = tf.stack(transitions.get_shape())
    strides = tf.cumprod(strides, exclusive=True, reverse=True)
    flat_indices = tf.reduce_sum(strides * pair_indices, axis=2)
    flat_transitions = tf.reshape(transitions, [-1])
    transition_score = tf.gather(flat_transitions, flat_indices)
    transition_score = transition_score * keep

    gold_path_score = (
        tf.reduce_sum(emission_score) + tf.reduce_sum(transition_score)
    )

    forward = Forward(emissions, transitions, nums_tags, lengths, batch_size)
    total_path_score, _, _ = forward()

    return -(gold_path_score - total_path_score)
评论列表
文章目录