seq_batch.py source code

python

Project: lang2program  Author: kelvinguu
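import tensorflow as tf

# expand_dims_for_broadcast and broadcast are helpers not shown in this excerpt


def change_pad_value(values, mask, pad_val):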
    """Given a set of values and a pad mask, change the value of all pad entries.

    Args:
        values (Tensor): of shape [batch_size, seq_length, :, ..., :].
        mask (Tensor): binary float tensor of shape [batch_size, seq_length];
            nonzero marks a real entry, 0 marks a pad entry
        pad_val (float): value to set all pad entries to

    Returns:
        Tensor: a new Tensor of same shape as values
    """
    # broadcast the mask to match shape of values
    mask = expand_dims_for_broadcast(mask, values)  # (batch_size, seq_length, 1, ..., 1)
    mask = broadcast(mask, values)
    mask = tf.cast(mask, tf.bool)  # cast to bool

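    # build a tensor filled with pad_val that has the same shape as values
    broadcast_val = pad_val * tf.ones(tf.shape(values))

    # keep values where mask is True, use pad_val elsewhere
    # (tf.select was removed in TensorFlow 1.0; tf.where takes the same arguments)
    new_values = tf.where(mask, values, broadcast_val)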
    return new_values
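
The masking logic above can be checked against a plain-Python reference. The sketch below is not part of lang2program; the names `fill_like` and `change_pad_value_reference` are hypothetical, and nested lists stand in for tensors. It assumes the same convention as the function above: a nonzero mask entry keeps the original value, and a zero mask entry is replaced by `pad_val` across all trailing dimensions.

```python
def fill_like(entry, pad_val):
    """Return a structure shaped like entry with every leaf set to pad_val."""
    if isinstance(entry, list):
        return [fill_like(e, pad_val) for e in entry]
    return pad_val


def change_pad_value_reference(values, mask, pad_val):
    """Plain-Python stand-in for change_pad_value (hypothetical helper).

    values: nested list indexed as [batch][seq], each entry a scalar or nested list
    mask:   nested list indexed as [batch][seq]; nonzero = real, 0 = pad
    """
    return [
        [entry if m else fill_like(entry, pad_val)
         for entry, m in zip(row, mask_row)]
        for row, mask_row in zip(values, mask)
    ]


# Two sequences of length 2, each step holding a 2-dimensional vector.
values = [[[1.0, 2.0], [3.0, 4.0]],
          [[5.0, 6.0], [7.0, 8.0]]]
mask = [[1.0, 0.0],
        [1.0, 1.0]]
result = change_pad_value_reference(values, mask, -1.0)
```

Here only the second step of the first sequence is masked out, so it is the only entry whose two components are overwritten with `pad_val`.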