def forward(self, input, context, src_pad_mask, tgt_pad_mask):
    """Run ``input`` through self-attention, context attention and the
    feed-forward sublayer, validating tensor sizes before and after.

    Size relationships are enforced through ``aeq``, a helper defined
    outside this block (presumably "assert all equal" -- confirm).

    Args:
        input: 3-D tensor; its first two sizes are used below as
            ``(batch, tgt_len)``.
        context: 3-D tensor; its first two sizes are used below as
            ``(batch, src_len)``.
        src_pad_mask: 3-D mask whose sizes are checked against
            ``(batch, tgt_len, src_len)``; forwarded as ``mask`` to
            ``self.context_attn``.
        tgt_pad_mask: 3-D mask whose sizes are checked against
            ``(batch, tgt_len, tgt_len)``; combined with ``self.mask``
            to build the mask given to ``self.self_attn``.

    Returns:
        A pair ``(output, attn)``. ``output`` is the result of
        ``self.feed_forward``; its first two sizes are checked against
        those of ``input``. ``attn`` is the second value returned by
        ``self.context_attn``; its sizes are checked against
        ``(batch, tgt_len, src_len)``.
    """
    # ---- argument size validation ----
    batch_size, tgt_len, _ = input.size()
    ctx_batch, src_len, _ = context.size()
    aeq(batch_size, ctx_batch)
    src_mask_batch, src_mask_rows, src_mask_cols = src_pad_mask.size()
    tgt_mask_batch, tgt_mask_rows, tgt_mask_cols = tgt_pad_mask.size()
    aeq(batch_size, ctx_batch, src_mask_batch, tgt_mask_batch)
    aeq(src_mask_rows, tgt_mask_rows, tgt_mask_cols, tgt_len)
    aeq(src_mask_cols, src_len)

    # Crop the stored mask to the current target length, broadcast it to
    # the pad mask's size, and flag every position where the sum of the
    # two masks is positive.
    # NOTE(review): self.mask looks like a "no peeking ahead" mask --
    # confirm where it is built.
    cropped_mask = self.mask[:, :tgt_mask_rows, :tgt_mask_rows]
    combined_mask = tgt_pad_mask + cropped_mask.expand_as(tgt_pad_mask)
    dec_mask = torch.gt(combined_mask, 0)

    # Self-attention on the normalised input; the second value it
    # returns is not used by this method.
    normed_input = self.layer_norm_1(input)
    query, _ = self.self_attn(
        normed_input, normed_input, normed_input, mask=dec_mask
    )

    # Context attention: context is passed for the first two arguments,
    # the normalised residual sum for the third.
    query_norm = self.layer_norm_2(query + input)
    mid, attn = self.context_attn(
        context, context, query_norm, mask=src_pad_mask
    )

    # Summation order is kept as (mid + query) + input on purpose so the
    # floating-point result is unchanged.
    output = self.feed_forward(mid + query + input)

    # ---- result size validation ----
    out_batch, out_len, _ = output.size()
    aeq(tgt_len, out_len)
    aeq(ctx_batch, out_batch)
    attn_batch, attn_rows, attn_cols = attn.size()
    aeq(batch_size, attn_batch)
    aeq(src_len, attn_cols)
    aeq(tgt_len, attn_rows)

    return output, attn
评论列表
文章目录