def change_pad_value(values, mask, pad_val):
    """Given a set of values and a pad mask, change the value of all pad entries.

    Entries whose mask value is nonzero are kept unchanged; entries whose mask
    value is zero are treated as padding and replaced with ``pad_val``.

    Args:
        values (Tensor): of shape [batch_size, seq_length, :, ..., :].
        mask (Tensor): binary float tensor of shape [batch_size, seq_length];
            1 marks a real (kept) entry, 0 marks a pad entry.
        pad_val (float): value to set all pad entries to. It is cast to
            ``values.dtype`` before being written.

    Returns:
        Tensor: a new Tensor with the same shape and dtype as ``values``.
    """
    # Broadcast the mask to match the shape of values.
    # NOTE(review): relies on this file's expand_dims_for_broadcast/broadcast
    # helpers yielding a mask with exactly the shape of `values` — confirm.
    mask = expand_dims_for_broadcast(mask, values)  # (batch_size, seq_length, 1, ..., 1)
    mask = broadcast(mask, values)
    mask = tf.cast(mask, tf.bool)  # nonzero -> True (keep), zero -> False (pad)

    # Build the fill tensor in values' own dtype. tf.ones() defaults to
    # float32, which made the select fail for float64/float16/integer values.
    broadcast_val = tf.fill(tf.shape(values), tf.cast(pad_val, values.dtype))

    # tf.select was removed in TensorFlow 1.0; tf.where(cond, x, y) replaces it.
    return tf.where(mask, values, broadcast_val)
# Comment list (scraped web-page residue: "评论列表")
# Table of contents (scraped web-page residue: "文章目录")