import tensorflow as tf


def _create_position_embedding(embedding_dim, num_positions, lengths, maxlen):
  """Creates position embeddings.

  Args:
    embedding_dim: Dimensionality of the embeddings. An integer.
    num_positions: The number of positions to be embedded. For example,
      if you have inputs of length up to 100, this should be 100. An integer.
    lengths: The lengths of the inputs to create position embeddings for.
      An int32 tensor of shape `[batch_size]`.
    maxlen: The maximum length of the input sequence to create position
      embeddings for. An int32 tensor.

  Returns:
    A tensor of shape `[batch_size, maxlen, embedding_dim]` that contains
    embeddings for each position. All elements past `lengths` are zero.
  """
  # Create constant position encodings. `position_encoding` is the helper
  # (defined elsewhere in this module) that builds the fixed
  # [num_positions, embedding_dim] encoding matrix.
  position_encodings = tf.constant(
      position_encoding(num_positions, embedding_dim),
      name="position_encoding")
  # Slice to the size of the current sequence.
  pe_slice = position_encodings[:maxlen, :]
  # Replicate the encodings for each element in the batch.
  batch_size = tf.shape(lengths)[0]
  pe_batch = tf.tile([pe_slice], [batch_size, 1, 1])
  # Mask out positions that are padded.
  positions_mask = tf.sequence_mask(
      lengths=lengths, maxlen=maxlen, dtype=tf.float32)
  positions_embed = pe_batch * tf.expand_dims(positions_mask, 2)
  return positions_embed
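
As a rough illustration of how this helper might be exercised, the sketch below builds position embeddings for a toy batch of two sequences. The `position_encoding` definition here is only a stand-in for the real encoding helper assumed to exist in the module; any function returning a `[num_positions, embedding_dim]` float32 matrix would do.

import numpy as np
import tensorflow as tf

# Stand-in for the real `position_encoding` helper (an assumption for this
# sketch); it just needs to return a fixed [num_positions, embedding_dim]
# float32 matrix of encodings.
def position_encoding(num_positions, embedding_dim):
  return np.random.RandomState(0).rand(
      num_positions, embedding_dim).astype(np.float32)

# Two sequences of lengths 3 and 5, padded to maxlen=5.
lengths = tf.constant([3, 5], dtype=tf.int32)
embed = _create_position_embedding(
    embedding_dim=8, num_positions=100, lengths=lengths, maxlen=5)

with tf.Session() as sess:
  out = sess.run(embed)
  print(out.shape)         # (2, 5, 8)
  print(out[0, 3:].sum())  # 0.0 -- positions past length 3 are masked out

Note how the `tf.sequence_mask` multiplication zeroes out the embeddings for padded positions, so the first example contributes nothing beyond its true length of 3.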