def attend(self, query, key, value, mask, minfs=None):
    """Compute masked dot-product attention and the attention-weighted sum.

    Input shapes:
        query: (b, units, dec_l)
        key:   (b, units, enc_l)
        value: (b, units, dec_l, enc_l)
        mask:  (b, dec_l, enc_l)

    Args:
        query: Query vectors; transposed inside the batched matmul.
        key: Key vectors.
        value: Values weighted by the attention distribution and summed
            over the last (encoder) axis.
        mask: Condition passed to ``F.where``. Score positions where it is
            false are replaced by the matching entry of ``minfs`` before the
            softmax (used to hide zero-padded areas).
        minfs: Optional fill array with the score shape (b, dec_l, enc_l).
            When None, a ``-inf`` array of that shape and of the scores'
            dtype is allocated with ``self.xp``. Presumably passed in by
            callers to avoid reallocating it on every call — TODO confirm.

    Returns:
        Context tensor of shape (b, units, dec_l, 1).
    """
    # Attention scores: (b, dec_l, units) @ (b, units, enc_l) -> (b, dec_l, enc_l)
    pre_a = F.batch_matmul(query, key, transa=True)
    if minfs is None:
        minfs = self.xp.full(pre_a.shape, -np.inf, pre_a.dtype)
    # Replace masked-out (zero-padded) positions with the fill value (-inf
    # by default) so they receive zero weight after the softmax.
    pre_a = F.where(mask, pre_a, minfs)
    a = F.softmax(pre_a, axis=2)
    # If every score along axis=2 is -inf, softmax yields nan for that row;
    # re-mask by replacing nan with 0. Note: this zeroes nan from any cause,
    # not only fully-masked rows.
    a = F.where(self.xp.isnan(a.data),
                self.xp.zeros(a.shape, dtype=a.dtype), a)
    # Insert a singleton units axis: (b, 1, dec_l, enc_l)
    reshaped_a = a[:, None]
    # Weighted sum over the encoder axis.
    pre_c = F.broadcast_to(reshaped_a, value.shape) * value
    c = F.sum(pre_c, axis=3, keepdims=True)  # (b, units, dec_l, 1)
    return c
评论列表
文章目录