def __call__(self, x, z, ze, mask, conv_mask):
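    # att_scale = sqrt(m), where m is the number of valid source tokens seen by
    # each query position; the attention context is rescaled by this factor,
    # matching the m * sqrt(1/m) normalization described in the ConvS2S paper.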
att_scale = self.xp.sum(
mask, axis=2, keepdims=True)[:, None, :, :] ** 0.5
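    # Zero left-padding of (width - 1) steps keeps the decoder convolution
    # causal: each output position only sees the current and earlier inputs.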
pad = self.xp.zeros(
(x.shape[0], x.shape[1], self.width - 1, 1), dtype=x.dtype)
base_x = x
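    # Drop the trailing singleton axis of the encoder states so they can be
    # used directly as attention keys.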
z = F.squeeze(z, axis=3)
    # Note: the handling of the input, output, and attention result here
    # follows the authors' reference implementation, which differs slightly
    # from the description in the paper.
for conv_name, preatt_name in zip(self.conv_names, self.preatt_names):
# Calculate Output of GLU
out = getattr(self, conv_name)(
F.concat([pad, x], axis=2), conv_mask)
        # Calculate Output of Attention using Output of GLU
preatt = seq_linear(getattr(self, preatt_name), out)
query = base_x + preatt
query = F.squeeze(query, axis=3)
c = self.attend(query, z, ze, mask) * att_scale
        # Merge Them in Residual Calculation and Scaling
x = (x + (c + out) * scale05) * scale05
return x
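The helpers seq_linear and self.attend, and the constant scale05, are defined outside this excerpt; scale05 is presumably 0.5 ** 0.5, following the ConvS2S paper's convention of multiplying the sum of a residual block's input and output by sqrt(0.5) to halve its variance. As a rough, non-authoritative sketch of what attend computes, here is a minimal dot-product attention; the function name, argument shapes, and masking scheme are assumptions inferred from how the result is used above, not the authors' implementation.

import numpy as np
import chainer.functions as F

def attend_sketch(query, z, ze, mask):
    # Assumed shapes (illustrative only):
    #   query: (batch, units, dec_len)  decoder queries (base_x + preatt)
    #   z:     (batch, units, enc_len)  encoder outputs, used as keys
    #   ze:    (batch, units, enc_len)  encoder outputs + source embeddings, as values
    #   mask:  (batch, dec_len, enc_len) 1 at valid source positions, 0 at padding
    scores = F.matmul(query, z, transa=True)         # (batch, dec_len, enc_len)
    # Suppress padded source positions before the softmax
    neg = np.full(scores.shape, -1e9, dtype=np.float32)
    scores = F.where(mask.astype(bool), scores, neg)
    weights = F.softmax(scores, axis=2)              # attention over source positions
    context = F.matmul(ze, weights, transb=True)     # (batch, units, dec_len)
    # Restore the trailing singleton axis so the result broadcasts with x and att_scale
    return F.expand_dims(context, axis=3)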