def get_pooling_batch(hs, mask, pooling_method):
    """
    Pool a batch of sequences over the length axis (axis 1).

    :param hs: (batch, len, dim)
    :param mask: (batch, len); positions where mask is 0 are excluded from
        'max', 'min', mean and 'sum' pooling (presumably 1 = real step,
        0 = padding -- confirm against callers)
    :param pooling_method: one of 'max', 'min',
        'averaging' / 'mean' / 'average', 'sum', 'final' / 'last'
    :return: hs reduced over axis 1, i.e. (batch, dim)
    :raises NotImplementedError: if pooling_method is not one of the above
    """
    if pooling_method == 'max':
        # Shift masked-out positions down by BIG_INT so they cannot be the max
        # (assumes BIG_INT dominates the values in hs -- it is defined elsewhere).
        add_v = ((1 - mask) * -BIG_INT)[:, :, None]
        return T.max(hs + add_v, axis=1)

    if pooling_method == 'min':
        # Mirror image of 'max': shift masked-out positions up by BIG_INT.
        add_v = ((1 - mask) * BIG_INT)[:, :, None]
        return T.min(hs + add_v, axis=1)

    if pooling_method in ('averaging', 'mean', 'average'):
        total = T.sum(hs * mask[:, :, None], axis=1)
        count = T.sum(mask, axis=1)
        # A row whose mask is entirely zero has count == 0; dividing by it
        # would yield NaN. Substitute 1 only in that case, so such rows pool
        # to their (zero) masked sum and every other row is unchanged.
        safe_count = T.switch(T.eq(count, 0), 1, count)
        return total / safe_count[:, None]

    if pooling_method == 'sum':
        return T.sum(hs * mask[:, :, None], axis=1)

    if pooling_method in ('final', 'last'):
        # NOTE: takes the last position along axis 1 and does not consult
        # mask, so for a padded row this is whatever sits at index -1.
        return hs[:, -1, :]

    raise NotImplementedError('Not implemented pooling method: {}'.format(pooling_method))
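# Comment list          (navigation label left over from the scraped web page)
# Table of contents     (navigation label left over from the scraped web page)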