def _max_pooling_matching(self, h1, h2, w):
    """Max-pooling matching of `h1` against `h2` for one weight direction.

    Every timestep of `h1` is compared with every timestep of `h2` under
    each of the `mp_dim` weight rows in `w`. For each `h1` timestep and
    each weight row, the largest cosine similarity over all `h2`
    timesteps is kept.

    # Arguments
        h1: (batch_size, h1_timesteps, embedding_size)
        h2: (batch_size, h2_timesteps, embedding_size)
        w: weights of one direction, (mp_dim, embedding_size)
    # Output shape
        (batch_size, h1_timesteps, mp_dim)
    """
    # Weight both sequences (h1 first, then h2); each result is
    # (batch_size, timesteps, mp_dim, embedding_size).
    weighted_1, weighted_2 = [
        self._time_distributed_multiply(seq, w) for seq in (h1, h2)
    ]
    # Lay the sequences out on separate axes so every h1 step meets every
    # h2 step:
    #   h1 -> (batch_size, h1_timesteps, 1, mp_dim, embedding_size)
    #   h2 -> (batch_size, 1, h2_timesteps, mp_dim, embedding_size)
    # giving a cosine tensor of (batch_size, h1_timesteps, h2_timesteps, mp_dim).
    pairwise_cos = self._cosine_similarity(
        K.expand_dims(weighted_1, axis=2),
        K.expand_dims(weighted_2, axis=1),
    )
    # Max over the h2_timesteps axis -> (batch_size, h1_timesteps, mp_dim).
    return K.max(pairwise_cos, axis=2)
评论列表
文章目录