def _lengths_to_masks(lengths, max_length):
    """Creates a binary matrix that can be used to mask away padding.

    Row ``i`` of the result holds ones in columns ``[0, lengths[i])`` and
    zeros in the remaining columns up to ``max_length``.

    Args:
        lengths: A vector of integers representing lengths.
        max_length: An integer indicating the maximum length. All values in
            lengths should be less than or equal to max_length; a length
            larger than max_length simply yields a row of all ones, and a
            non-positive length yields a row of all zeros.

    Returns:
        masks: A float tensor of shape ``[len(lengths), max_length]`` with
            1.0 at valid (non-padding) positions and 0.0 at padding.
    """
    # Shape [1, max_length]: the position index of every column.
    positions = array_ops.expand_dims(math_ops.range(max_length), 0)
    # Shape [batch, 1]: one length per row.
    lengths = array_ops.expand_dims(lengths, 1)
    # The comparison broadcasts [1, max_length] against [batch, 1] into
    # [batch, max_length], so the positions do not need to be tiled first.
    # Both sides are cast to int64 so int32 `range` output and int64 lengths
    # can be compared without a dtype mismatch.
    masks = math_ops.to_float(
        math_ops.to_int64(positions) < math_ops.to_int64(lengths))
    return masks
评论列表
文章目录