seq_batch.py source file

python

Project: lang2program  Author: kelvinguu
import tensorflow as tf

# Note: weighted_sum is called below but is not shown in this excerpt.

def reduce_mean(seq_batch, allow_empty=False):
    """Compute the mean of each sequence in a SequenceBatch.

    Args:
        seq_batch (SequenceBatch): a SequenceBatch with the following attributes:
            values (Tensor): a Tensor of shape (batch_size, seq_length, :, ..., :)
            mask (Tensor): a Tensor of shape (batch_size, seq_length). If the mask values are arbitrary
                floats (rather than binary), the mean will be a weighted average.
        allow_empty (bool): allow computing the average of an empty sequence. In this case, we assume 0/0 == 0, rather
            than NaN. Default is False, causing an error to be thrown.

    Returns:
        Tensor: of shape (batch_size, :, ..., :)
    """
    values, mask = seq_batch.values, seq_batch.mask
    # compute weights for the average
    sums = tf.reduce_sum(mask, 1, keep_dims=True)  # (batch_size, 1)

    if allow_empty:
        asserts = []  # no assertion
        # tf.select was renamed to tf.where in TensorFlow 1.0
        sums = tf.where(tf.equal(sums, 0), tf.ones_like(sums), sums)  # replace 0's with 1's
    else:
        asserts = [tf.assert_positive(sums)]  # throw error if 0's exist

    with tf.control_dependencies(asserts):
        weights = mask / sums  # (batch_size, seq_length)
    return weighted_sum(seq_batch, weights)
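
The arithmetic in reduce_mean can be checked without TensorFlow. Below is a minimal plain-Python sketch, not the project's implementation. The name masked_mean is hypothetical, sequence elements are restricted to scalars, and weighted_sum (not shown in this excerpt) is assumed to sum weights * values over the sequence axis.

```python
def masked_mean(values, mask, allow_empty=False):
    """Plain-Python analogue of reduce_mean for scalar sequence elements.

    values and mask are lists of equal-length rows, i.e. (batch_size, seq_length).
    """
    means = []
    for row_values, row_mask in zip(values, mask):
        total = sum(row_mask)  # plays the role of `sums` in reduce_mean
        if allow_empty:
            if total == 0:
                total = 1.0  # treat 0/0 as 0 by dividing by 1 instead
        elif total <= 0:
            # mirrors tf.assert_positive(sums)
            raise ValueError("mask sum must be positive for every sequence")
        weights = [m / total for m in row_mask]
        # assumed behaviour of weighted_sum: sum over the sequence axis
        means.append(sum(w * v for w, v in zip(weights, row_values)))
    return means


values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
mask = [[1.0, 1.0, 0.0],   # binary mask: plain mean of the first two items
        [1.0, 3.0, 0.0],   # float mask: weighted average
        [0.0, 0.0, 0.0]]   # empty sequence
result = masked_mean(values, mask, allow_empty=True)
print(result)  # [1.5, 4.75, 0.0]
```

With allow_empty left at its default of False, the all-zero third row raises an error instead of returning 0, which corresponds to the assertion branch in reduce_mean.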