def reduce_mean(seq_batch, allow_empty=False):
    """Compute the mean of each sequence in a SequenceBatch.

    The mask is normalized by its per-sequence sum to obtain averaging
    weights, which are then handed to ``weighted_sum``.

    Args:
        seq_batch (SequenceBatch): a SequenceBatch with the following attributes:
            values (Tensor): a Tensor of shape (batch_size, seq_length, :, ..., :)
            mask (Tensor): a Tensor of shape (batch_size, seq_length). If the mask
                values are arbitrary floats (rather than binary), the mean will be
                a weighted average.
        allow_empty (bool): allow computing the average of an empty sequence
            (one whose mask sums to 0). In this case, we assume 0/0 == 0, rather
            than NaN. Default is False, in which case a ``tf.assert_positive``
            check on the mask sums is attached to the weight computation.

    Returns:
        Tensor: of shape (batch_size, :, ..., :)
    """
    mask = seq_batch.mask

    # Normalizer for the average: total mask weight of each sequence.
    sums = tf.reduce_sum(mask, 1, keep_dims=True)  # (batch_size, 1)

    if allow_empty:
        asserts = []  # nothing to check; zero sums are patched below
        # Replace 0's with 1's so that an all-zero mask row yields 0 / 1 == 0.
        # ones_like inherits the dtype of `sums`; tf.ones would default to
        # float32 and clash with a non-float32 mask inside tf.select.
        sums = tf.select(tf.equal(sums, 0), tf.ones_like(sums), sums)
    else:
        asserts = [tf.assert_positive(sums)]  # fail if any sum is not > 0

    # Make the division depend on the assertion so the check cannot be skipped.
    with tf.control_dependencies(asserts):
        weights = mask / sums  # (batch_size, seq_length)
    return weighted_sum(seq_batch, weights)
评论列表
文章目录