# assumes: from keras import backend as K; from keras.layers import Lambda
def __call__(self, x1, x2):
    def _dot_product(args):
        x = args[0]
        y = args[1]
        # raw alignment scores between every timestep pair:
        # (batch_size, timesteps1, timesteps2)
        return K.batch_dot(x, K.permute_dimensions(y, (0, 2, 1)))

    def _normalize(args, transpose=False):
        att_w = args[0]
        x = args[1]
        if transpose:
            att_w = K.permute_dimensions(att_w, (0, 2, 1))
        # numerically stable softmax over the last axis
        e = K.exp(att_w - K.max(att_w, axis=-1, keepdims=True))
        sum_e = K.sum(e, axis=-1, keepdims=True)
        nor_e = e / sum_e
        # weighted sum of the other sequence's timesteps
        return K.batch_dot(nor_e, x)
    # (batch_size, timesteps1, dim)
    f1 = self.model(x1)
    # (batch_size, timesteps2, dim)
    f2 = self.model(x2)

    output_shape = (self.sequence_length, self.sequence_length,)
    # attention weights, (batch_size, timesteps1, timesteps2)
    att_w = Lambda(
        _dot_product,
        output_shape=output_shape)([f1, f2])

    output_shape = (self.sequence_length, self.input_dim,)
    # (batch_size, timesteps1, dim)
    att1 = Lambda(
        _normalize, arguments={'transpose': False},
        output_shape=output_shape)([att_w, x2])
    # (batch_size, timesteps2, dim)
    att2 = Lambda(
        _normalize, arguments={'transpose': True},
        output_shape=output_shape)([att_w, x1])

    return att1, att2
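To see what the two Lambda helpers compute, here is a standalone NumPy sketch of the same soft alignment (the names `softmax` and `co_attention` are illustrative, not part of the original class): the score matrix is a batched dot product of the encoded sequences, and each output is a row-wise softmax of that matrix multiplied back into the other raw sequence.

import numpy as np

def softmax(a):
    # numerically stable softmax over the last axis, as in _normalize
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def co_attention(f1, f2, x1, x2):
    # scores between every timestep pair: (batch, t1, t2)
    att_w = f1 @ f2.transpose(0, 2, 1)
    # x2 aligned to each timestep of x1: (batch, t1, dim)
    att1 = softmax(att_w) @ x2
    # x1 aligned to each timestep of x2: (batch, t2, dim)
    att2 = softmax(att_w.transpose(0, 2, 1)) @ x1
    return att1, att2

rng = np.random.default_rng(0)
f1, x1 = rng.random((2, 4, 8)), rng.random((2, 4, 8))
f2, x2 = rng.random((2, 5, 8)), rng.random((2, 5, 8))
att1, att2 = co_attention(f1, f2, x1, x2)
assert att1.shape == (2, 4, 8) and att2.shape == (2, 5, 8)
# every row of normalized attention weights sums to 1
assert np.allclose(softmax(f1 @ f2.transpose(0, 2, 1)).sum(axis=-1), 1.0)

Note that the Keras version above pins both lengths to `self.sequence_length` in `output_shape`, while the underlying math also works for sequences of unequal length, as the sketch shows.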