beam_aligner.py source file

python

Project: almond-nnparser  Author: Stanford-Mobisocial-IoT-Lab
def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
    """Maybe applies _tensor_gather_helper.
    This applies _tensor_gather_helper only when gather_from has at least as
    many dimensions as gather_shape has entries. It is used together with nest
    so that _tensor_gather_helper is not applied to values it cannot handle,
    such as scalars.
    Args:
      gather_indices: The tensor indices that we use to gather.
      gather_from: The tensor that we are gathering from.
      batch_size: The batch size.
      range_size: The number of values in each range. Likely equal to beam_width.
      gather_shape: What we should reshape gather_from to in order to preserve the
        correct values. An example is when gather_from is the attention from an
        AttentionWrapperState with shape [batch_size, beam_width, attention_size].
        There, we want to preserve the attention_size elements, so gather_shape is
        [batch_size * beam_width, -1]. Then, upon reshape, we still have the
        attention_size as desired.
    Returns:
      output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
        or the original tensor if its dimensions are too small.
    """
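    # Gather only when the tensor has enough dimensions to be reshaped to
    # gather_shape; lower-rank values such as scalars are returned unchanged.
    if gather_from.shape.ndims >= len(gather_shape):
        return _tensor_gather_helper(
            gather_indices=gather_indices,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=range_size,
            gather_shape=gather_shape)
    return gather_from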
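
The rank check above can be illustrated without TensorFlow. The following is a minimal sketch using nested Python lists in place of tensors. The names `rank`, `gather_beams`, and `maybe_gather_beams` are hypothetical stand-ins, not part of almond-nnparser or TensorFlow, and `min_rank` plays the role of `len(gather_shape)`. It assumes the gather flattens `[batch, beam, ...]` to `[batch * beam, ...]` and offsets each batch entry's indices by `batch_index * range_size`, as the docstring's description of `gather_shape` suggests.

```python
def rank(value):
    """Count nested list levels; a plain number has rank 0."""
    depth = 0
    while isinstance(value, list):
        depth += 1
        if not value:  # stop at an empty list instead of indexing into it
            break
        value = value[0]
    return depth


def gather_beams(gather_indices, gather_from, range_size):
    """Pick beams per batch entry (hypothetical stand-in for the real helper)."""
    # Flatten [batch, beam, ...] into [batch * beam, ...].
    flat = [beam for entry in gather_from for beam in entry]
    gathered = []
    for batch_index, indices in enumerate(gather_indices):
        # Shift each index into this batch entry's slice of the flat list.
        offset = batch_index * range_size
        gathered.append([flat[offset + i] for i in indices])
    return gathered


def maybe_gather_beams(gather_indices, gather_from, range_size, min_rank):
    """Gather only when the value has at least min_rank dimensions."""
    if rank(gather_from) >= min_rank:
        return gather_beams(gather_indices, gather_from, range_size)
    return gather_from


# Two batch entries, two beams each, two features per beam.
state = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[1, 1], [0, 1]]
reordered = maybe_gather_beams(indices, state, range_size=2, min_rank=2)
untouched = maybe_gather_beams(indices, 7, range_size=2, min_rank=2)
```

Here `reordered` holds the beams selected by `indices` for each batch entry, while the scalar `7` is passed through unchanged. That pass-through is the behaviour the `_maybe_` wrapper exists to provide when it is mapped over a nested decoder state.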