misc.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def mask_for_lengths(lengths, max_length=None, mask_right=True, value=-1000.0):
    """
    Creates a [batch_size x max_length] mask.

    Args:
        lengths: int32 1-dim tensor of batch_size lengths
        max_length: int32 0-dim tensor or python int
        mask_right: if True, everything before "lengths" becomes zero and the
            rest "value", else vice versa
        value: value for the mask

    Returns:
        [batch_size x max_length] mask of zeros and "value"s
    """
    mask = tf.sequence_mask(lengths, max_length, dtype=tf.float32)
    if mask_right:
        mask = 1.0 - mask
    mask *= value
    return mask
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号