# NB: this method assumes the Keras backend is imported as `K`, and that
# `switch` is an element-wise select helper (keep `x` where the mask is on,
# zeros elsewhere), e.g. deep_qa's `tensors.backend.switch`; Keras's own
# `K.switch` does not broadcast element-wise on all backends.
def call(self, x, mask=None):
    # x: (batch_size, input_length, input_dim)
    # mean: (batch_size, input_dim)
    mean = super(IntraAttention, self).call(x, mask)
    ones = K.expand_dims(K.mean(K.ones_like(x), axis=(0, 2)), dim=0)  # (1, input_length)
    # Tile the mean over every timestep: (batch_size, input_length, input_dim)
    tiled_mean = K.permute_dimensions(K.dot(K.expand_dims(mean), ones), (0, 2, 1))
    if mask is not None:
        if K.ndim(mask) > K.ndim(x):
            # Assuming this is because of the bug in Bidirectional. Temporary fix follows.
            # TODO: Fix Bidirectional.
            mask = K.any(mask, axis=(-2, -1))
        if K.ndim(mask) < K.ndim(x):
            mask = K.expand_dims(mask)
        # Zero out masked timesteps so they contribute nothing to the attended sum.
        x = switch(mask, x, K.zeros_like(x))
    # (batch_size, input_length, proj_dim)
    projected_combination = K.tanh(K.dot(x, self.vector_projector) + K.dot(tiled_mean, self.mean_projector))
    scores = K.dot(projected_combination, self.scorer)  # (batch_size, input_length)
    weights = K.softmax(scores)  # (batch_size, input_length)
    attended_x = K.sum(K.expand_dims(weights) * x, axis=1)  # (batch_size, input_dim)
    return attended_x
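
Since the method above lives inside a larger layer class, here is a self-contained NumPy sketch of the same shape flow. The concrete sizes and the randomly initialized stand-ins for `vector_projector`, `mean_projector`, and `scorer` are illustrative assumptions, not taken from the source:

import numpy as np

# Illustrative sizes (assumed, not from the source).
batch_size, input_length, input_dim, proj_dim = 4, 10, 50, 20
x = np.random.rand(batch_size, input_length, input_dim)
mean = x.mean(axis=1)                            # (batch_size, input_dim)

# Random stand-ins for the layer's trained weights.
vector_projector = np.random.rand(input_dim, proj_dim)
mean_projector = np.random.rand(input_dim, proj_dim)
scorer = np.random.rand(proj_dim)

# Tile the mean over timesteps (broadcasting replaces the dot-with-ones trick).
tiled_mean = np.broadcast_to(mean[:, np.newaxis, :], x.shape)
projected = np.tanh(x @ vector_projector + tiled_mean @ mean_projector)
scores = projected @ scorer                      # (batch_size, input_length)
exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)  # softmax over timesteps
attended_x = (weights[:, :, np.newaxis] * x).sum(axis=1)
print(attended_x.shape)                          # (4, 50)

Broadcasting the mean directly is equivalent to the layer's dot-with-a-row-of-ones tiling, which is presumably written that way to handle a symbolic (unknown at graph-construction time) input_length in the Keras backend.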