def embed(sequence_batch, embeds):
    """Look up embeddings for the entries of a sequence batch and mask them.

    Args:
        sequence_batch: object exposing ``values`` (used as the indices for
            ``tf.gather``) and ``mask``.
        embeds: tensor whose slices are gathered by ``sequence_batch.values``.

    Returns:
        A ``SequenceBatch`` built from the masked embeddings and the
        unchanged input mask.
    """
    pad_mask = sequence_batch.mask

    # Gather the embeddings and check that none of them are NaN/Inf.
    looked_up = tf.verify_tensor_all_finite(
        tf.gather(embeds, sequence_batch.values), 'embedded_values')

    # Zero out the embeddings at padded positions by multiplying with the
    # mask, expanded so it broadcasts against the embedded tensor.
    # NOTE(review): presumably the mask is 0 at pad positions -- confirm
    # against the SequenceBatch definition.
    masked = looked_up * expand_dims_for_broadcast(pad_mask, looked_up)

    return SequenceBatch(masked, pad_mask)
评论列表
文章目录