tf_util.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:distributional_perspective_on_RL 作者: Kiwoo 项目源码 文件源码
def lengths_to_mask(lengths_b, max_length):
    """
    Turns a vector of lengths into a boolean mask

    Args:
        lengths_b: an integer vector of lengths
        max_length: maximum length to fill the mask

    Returns:
        a boolean array of shape (batch_size, max_length)
        row[i] consists of True repeated lengths_b[i] times, followed by False
    """
    lengths_b = tf.convert_to_tensor(lengths_b)
    assert lengths_b.get_shape().ndims == 1
    mask_bt = tf.expand_dims(tf.range(max_length), 0) < tf.expand_dims(lengths_b, 1)
    return mask_bt
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号