masking.py source code

Project: jack  Author: uclmr
import tensorflow as tf


def mask_3d(sequences, sequence_lengths, mask_value, dimension=2):
    """
    Given a batch of matrices, each with shape m x n, mask the values in each
    row beyond the first sequence_lengths[b] positions of example b.
    This function is meant to mask the trailing columns of the raw attention
    matrix (e_{i, j}) when sentence2 is shorter than the maximum length in
    the batch.

    Args:
        sequences: tensor with shape (batch_size, m, n)
        sequence_lengths: integer tensor with shape (batch_size,) containing the
           sentence lengths to which each matrix should be limited
        mask_value: scalar value to assign to items beyond the sentence length
        dimension: dimension along which to mask values; 2 (the default) masks
           columns, 1 masks rows
    Returns:
        A tensor with the same shape as `sequences`
    """
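    if dimension == 1:
        # move the axis to be masked into the last position
        sequences = tf.transpose(sequences, [0, 2, 1])
    time_steps1, time_steps2 = tf.shape(sequences)[1], tf.shape(sequences)[2]
    # tensor filled with mask_value; using the dtype of `sequences` keeps the
    # two branches of tf.where type-compatible for any floating-point input
    pad_values = tf.fill(tf.shape(sequences), tf.cast(mask_value, sequences.dtype))
    # mask is (batch_size, time_steps2): True at positions inside each length
    mask = tf.sequence_mask(sequence_lengths, time_steps2)
    # tile it along the middle axis to get a 3D mask matching `sequences`
    mask3d = tf.tile(tf.expand_dims(mask, 1), (1, time_steps1, 1))
    # keep the original values where the mask is True, mask_value elsewhere
    masked = tf.where(mask3d, sequences, pad_values)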
    return tf.transpose(masked, [0, 2, 1]) if dimension == 1 else masked
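To make the semantics concrete, here is a NumPy re-implementation of the same masking logic. It is a sketch written for illustration, not code from the jack project; `mask_3d_np` and `softmax` are hypothetical helper names. Broadcasting replaces the explicit `tf.tile`, and the last lines show the attention use case from the docstring: masking padded columns with `-inf` before a row-wise softmax gives those columns zero weight.

```python
import numpy as np


def mask_3d_np(sequences, sequence_lengths, mask_value, dimension=2):
    """Hypothetical NumPy sketch of mask_3d: along `dimension`, every position
    at index >= sequence_lengths[b] in example b is replaced by mask_value."""
    sequences = np.asarray(sequences)
    lengths = np.asarray(sequence_lengths)
    if dimension == 1:
        # Move the axis to be masked to the end, as the TensorFlow version does.
        sequences = np.transpose(sequences, (0, 2, 1))
    n = sequences.shape[2]
    # Equivalent of tf.sequence_mask(lengths, n): True where index < length.
    mask = np.arange(n)[None, :] < lengths[:, None]  # shape (batch, n)
    # Broadcasting over the middle axis plays the role of tf.tile.
    masked = np.where(mask[:, None, :], sequences, mask_value)
    return np.transpose(masked, (0, 2, 1)) if dimension == 1 else masked


def softmax(x, axis=-1):
    """Numerically stable softmax; entries equal to -inf get probability 0."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)


# Two 3 x 4 score matrices; example 0 has length 2, example 1 has length 4.
scores = np.arange(24, dtype=np.float64).reshape(2, 3, 4)

# dimension=2: columns 2 and 3 of example 0 become 0.0; example 1 is untouched.
out = mask_3d_np(scores, [2, 4], 0.0)

# dimension=1: lengths now index the 3 rows, so row 2 of example 0 becomes 0.0.
out_rows = mask_3d_np(scores, [2, 3], 0.0, dimension=1)

# Attention use case: masked columns receive zero weight after the softmax.
weights = softmax(mask_3d_np(scores, [2, 4], -np.inf))
```

With `-inf` as the mask value, a sequence length of 0 would leave a row with no finite entries and the softmax would return NaN for it; a large finite negative value avoids that edge case.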