ChainCRF.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:emnlp2017-bilstm-cnn-crf 作者: UKPLab 项目源码 文件源码
def batch_gather(reference, indices):
        '''Batchwise gathering of row indices.

        The numpy equivalent is reference[np.arange(batch_size), indices].

        # Arguments
            reference: tensor with ndim >= 2 of shape
              (batch_size, dim1, dim2, ..., dimN)
            indices: 1d integer tensor of shape (batch_size) satisfiying
              0 <= i < dim2 for each element i.

        # Returns
            A tensor with shape (batch_size, dim2, ..., dimN)
            equal to reference[1:batch_size, indices]
        '''
        batch_size = K.shape(reference)[0]
        indices = tf.pack([tf.range(batch_size), indices], axis=1)
        return tf.gather_nd(reference, indices)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号