tensorflow_backend.py source code

python

Project: SGAITagger  Author: zhiweiuu
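import tensorflow as tf


def batch_gather(reference, indices):
    '''Batchwise gathering of row indices.

    The NumPy equivalent is reference[np.arange(batch_size), indices].

    # Arguments
        reference: tensor with ndim >= 2 of shape
          (batch_size, dim1, dim2, ..., dimN)
        indices: 1D integer tensor of shape (batch_size,) satisfying
          0 <= i < dim1 for each element i.

    # Returns
        A tensor with shape (batch_size, dim2, ..., dimN)
        equal to reference[np.arange(batch_size), indices].
    '''
    batch_size = tf.shape(reference)[0]
    # tf.shape and tf.range produce int32; cast indices so tf.stack
    # receives tensors of the same dtype (int64 indices would fail otherwise).
    indices = tf.cast(indices, batch_size.dtype)
    # Row b of `pairs` is [b, indices[b]], an index into the first two axes.
    pairs = tf.stack([tf.range(batch_size), indices], axis=1)
    return tf.gather_nd(reference, pairs)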
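To make the index construction concrete, here is a NumPy-only sketch that mirrors the same two steps (stack batch positions with row indices, then gather on the first two axes) so the result can be checked against the NumPy equivalent quoted in the docstring. The helper name `batch_gather_np` is hypothetical and not part of SGAITagger; it emulates `tf.gather_nd` with NumPy advanced indexing rather than calling TensorFlow.

```python
import numpy as np


def batch_gather_np(reference, indices):
    """Hypothetical NumPy sketch of batch_gather: picks row indices[b] from batch b."""
    reference = np.asarray(reference)
    indices = np.asarray(indices)
    batch_size = reference.shape[0]
    # Same role as tf.stack([...], axis=1): shape (batch_size, 2),
    # where row b is [b, indices[b]].
    pairs = np.stack([np.arange(batch_size), indices], axis=1)
    # Emulate tf.gather_nd: each pair indexes the first two axes of reference.
    return reference[pairs[:, 0], pairs[:, 1]]


ref = np.arange(24).reshape(2, 3, 4)  # batch_size=2, dim1=3, dim2=4
idx = np.array([2, 0])                # each entry must satisfy 0 <= i < dim1
out = batch_gather_np(ref, idx)
print(out.shape)  # (2, 4), i.e. (batch_size, dim2)
print(out)        # [[ 8  9 10 11]
                  #  [12 13 14 15]]
```

Batch 0 contributes its row 2 and batch 1 its row 0, so the output drops the `dim1` axis, matching the shape `(batch_size, dim2, ..., dimN)` stated in the docstring.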