def crf_binary_score(tag_indices, sequence_lengths, transition_params):
    """Computes the binary scores of tag sequences.

    For each sequence, the binary (transition) score is the sum of
    ``transition_params[tag[t], tag[t + 1]]`` over consecutive position pairs,
    with pairs masked out according to ``sequence_lengths``.

    Args:
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] matrix of binary potentials.

    Returns:
        binary_scores: A [batch_size] vector of binary scores.
    """
    # Get shape information. The static tag count is used when it is known.
    num_tags = transition_params.get_shape()[0]
    # TF1 returns a `Dimension` (unknown -> `.value is None`); TF2 returns an
    # int or None. Normalize so we can detect an unknown static dimension.
    static_num_tags = getattr(num_tags, "value", num_tags)
    max_seq_len = array_ops.shape(tag_indices)[1]
    num_transitions = max_seq_len - 1

    # Truncate by one on each side of the sequence to get the start and end
    # indices of each transition.
    start_tag_indices = array_ops.slice(tag_indices, [0, 0],
                                        [-1, num_transitions])
    end_tag_indices = array_ops.slice(tag_indices, [0, 1],
                                      [-1, num_transitions])

    if static_num_tags is None:
        # An unknown static dimension cannot be converted to a tensor, so the
        # multiplication below would fail. Read the tag count at run time
        # instead, in the same dtype as the tag indices.
        num_tags = math_ops.cast(array_ops.shape(transition_params)[0],
                                 start_tag_indices.dtype)

    # Encode each (start, end) tag pair as a row-major index into the
    # flattened [num_tags * num_tags] transition matrix.
    flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
    flattened_transition_params = array_ops.reshape(transition_params, [-1])

    # Get the binary scores based on the flattened representation; shape is
    # [batch_size, num_transitions] and dtype is that of transition_params.
    binary_scores = array_ops.gather(flattened_transition_params,
                                     flattened_transition_indices)

    masks = _lengths_to_masks(sequence_lengths, max_seq_len)
    # Drop mask column 0 so that mask column t+1 lines up with transition
    # t -> t+1. Presumably masks[b, j] is 1 for j < sequence_lengths[b] and 0
    # otherwise (see _lengths_to_masks) -- TODO confirm.
    truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1])
    # Match the potentials' dtype so the product type-checks regardless of the
    # dtype the mask helper returns; the 0/1 mask values are unchanged.
    truncated_masks = math_ops.cast(truncated_masks, binary_scores.dtype)
    binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1)
    return binary_scores
评论列表
文章目录