pooling.py source code

python

Project: dl4nlp_in_theano    Author: luyaojie
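import theano.tensor as T

# BIG_INT is a large module-level constant defined elsewhere in pooling.py
# (its value is not shown in this excerpt).


def get_pooling_batch(hs, mask, pooling_method):
    """
    Pool a batch of padded sequences along the time axis.

    :param hs:   (batch, len, dim) hidden states
    :param mask: (batch, len), 1 for real tokens and 0 for padding
    :param pooling_method: 'max', 'min', 'mean' (also 'average' / 'averaging'),
                           'sum', or 'final' (also 'last')
    :return: (batch, dim) pooled representation
    """
    if pooling_method == 'max':
        # Add a large negative offset at padded positions so they never win the max.
        add_v = ((1 - mask) * -BIG_INT)[:, :, None]
        return T.max(hs + add_v, axis=1)
    elif pooling_method == 'min':
        # Add a large positive offset at padded positions so they never win the min.
        add_v = ((1 - mask) * BIG_INT)[:, :, None]
        return T.min(hs + add_v, axis=1)
    elif pooling_method in ['averaging', 'mean', 'average']:
        # Masked sum divided by each sequence's true length.
        return T.sum(hs * mask[:, :, None], axis=1) / T.sum(mask, axis=1)[:, None]
    elif pooling_method == 'sum':
        # Masked sum over real tokens only.
        return T.sum(hs * mask[:, :, None], axis=1)
    elif pooling_method in ['final', 'last']:
        # Takes the last time step without consulting the mask, so this is only
        # correct when every sequence's final position is a real token
        # (e.g. left-padded batches or batches of equal-length sequences).
        return hs[:, -1, :]
    else:
        raise NotImplementedError('Not implemented pooling method: {}'.format(pooling_method))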
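To try the masking logic without Theano, here is a minimal NumPy sketch that mirrors get_pooling_batch. It is an illustration under stated assumptions, not the project's code: the names get_pooling_batch_np and last_valid_step are hypothetical, BIG_INT = 1e9 is a placeholder value, and last_valid_step (a mask-aware alternative to 'final'/'last' that the original does not provide) assumes right-padded sequences.

```python
import numpy as np

# Hypothetical placeholder for the module-level BIG_INT; the real value is not shown.
BIG_INT = 1e9


def get_pooling_batch_np(hs, mask, pooling_method):
    """NumPy sketch of get_pooling_batch.

    hs:   float array of shape (batch, len, dim)
    mask: 0/1 array of shape (batch, len); 1 marks a real token
    returns an array of shape (batch, dim)
    """
    mask = mask.astype(hs.dtype)
    if pooling_method == 'max':
        # Push padded positions far below any real activation before the max.
        return np.max(hs + ((1 - mask) * -BIG_INT)[:, :, None], axis=1)
    if pooling_method == 'min':
        # Push padded positions far above any real activation before the min.
        return np.min(hs + ((1 - mask) * BIG_INT)[:, :, None], axis=1)
    if pooling_method in ('averaging', 'mean', 'average'):
        # Sum over real tokens only, then divide by each sequence's true length.
        return np.sum(hs * mask[:, :, None], axis=1) / np.sum(mask, axis=1)[:, None]
    if pooling_method == 'sum':
        return np.sum(hs * mask[:, :, None], axis=1)
    if pooling_method in ('final', 'last'):
        # Mirrors the original: last time step, regardless of the mask.
        return hs[:, -1, :]
    raise NotImplementedError('Not implemented pooling method: {}'.format(pooling_method))


def last_valid_step(hs, mask):
    # Hypothetical mask-aware alternative (not in the original): select each
    # sequence's last real token, assuming padding sits at the end (right padding).
    last_idx = mask.sum(axis=1).astype(int) - 1
    return hs[np.arange(hs.shape[0]), last_idx]


hs = np.arange(12, dtype=np.float64).reshape(2, 3, 2)
mask = np.array([[1, 1, 0], [1, 1, 1]])  # first sequence has one padding step
print(get_pooling_batch_np(hs, mask, 'max'))   # [[ 2.  3.] [10. 11.]]
print(get_pooling_batch_np(hs, mask, 'last'))  # [[ 4.  5.] [10. 11.]]
print(last_valid_step(hs, mask))               # [[ 2.  3.] [10. 11.]]
```

In this batch the first sequence is padded at its third step, so 'last' returns the padding step's vector while last_valid_step returns the final real token; which behavior is wanted depends on how the batches are padded.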